Skip to content

Instantly share code, notes, and snippets.

@Frityet
Last active February 3, 2026 17:41
Show Gist options
  • Select an option

  • Save Frityet/265b7ea18da8a93e3566be88b17a78c9 to your computer and use it in GitHub Desktop.

Select an option

Save Frityet/265b7ea18da8a93e3566be88b17a78c9 to your computer and use it in GitHub Desktop.
local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local table = _tl_compat and _tl_compat.table or table
-- Order: finite state machine for an order lifecycle. The states, events,
-- guards and actions are configured inside Order.new.
local Order = { Context = {}, Transition = {}, Options = {}, DSL = {} }

-- Default hooks, used when neither `opts` nor the built-in configuration
-- supplies a replacement.
local __fsm_noop_action = function(context, _event, _from, _to)
  return context
end

local function __fsm_noop_invalid(_ctx, _event, _from)
  return false
end

local __fsm_noop_transition = function(context, _event, _from, _to)
  return context
end

local function __fsm_noop_trace(_msg) end

-- Gives a transition the default name "from:event->to" when it has none.
local function __fsm_normalize_transition(t)
  t.name = t.name or (tostring(t.from) .. ":" .. tostring(t.event) .. "->" .. tostring(t.to))
  return t
end

-- Returns a new list holding every transition of `list`, normalized.
local function __fsm_build_transitions(list)
  local out = {}
  for _, t in ipairs(list) do
    table.insert(out, __fsm_normalize_transition(t))
  end
  return out
end

-- Returns the distinct states and the distinct events mentioned by `list`,
-- each in order of first appearance.
local function __fsm_collect_states_events(list)
  local states, events = {}, {}
  local seen_states, seen_events = {}, {}
  for _, t in ipairs(list) do
    if not seen_states[t.from] then
      seen_states[t.from] = true
      table.insert(states, t.from)
    end
    if not seen_states[t.to] then
      seen_states[t.to] = true
      table.insert(states, t.to)
    end
    if not seen_events[t.event] then
      seen_events[t.event] = true
      table.insert(events, t.event)
    end
  end
  return states, events
end

-- Builds a transition record. `guard`, `action` and `name` are optional.
function Order.transition(from, event, to, guard, action, name)
  return { from = from, event = event, to = to, guard = guard, action = action, name = name }
end

-- Returns the first transition registered for (self.state, event) whose guard
-- allows it, or nil when there is none.
-- A transition without a guard is always eligible. The earlier expression
-- `(guard and guard(...)) or true` was true even when the guard returned
-- false, so guards never blocked anything; the guard result is now honoured,
-- and later guards are no longer invoked once a match has been found.
local function __fsm_find(self, event)
  local from = self.state
  for _, t in ipairs(self.transitions) do
    if t.from == from and t.event == event then
      local guard = t.guard
      if guard == nil or guard(self.ctx, event, from, t.to) then
        return t
      end
    end
  end
  return nil
end

-- Creates a machine instance.
--   context       table handed to guards, actions and hooks
--   initial_state optional starting state (defaults to "Cart")
--   opts          optional overrides: name, strict, on_invalid, on_transition,
--                 on_enter, on_exit, trace, states, events
-- Raises (via assert) when no initial state can be determined.
function Order.new(context, initial_state, opts)
  local cfg = {
    name = "Order",
    strict = false,
    initial = nil,
    transitions = {},
    on_invalid = __fsm_noop_invalid,
    on_transition = __fsm_noop_transition,
    on_enter = __fsm_noop_action,
    on_exit = __fsm_noop_action,
    trace = __fsm_noop_trace,
    states = {},
    events = {},
  }

  -- Built-in configuration of the Order machine.
  do
    cfg.name = "Order"
    cfg.strict = true
    cfg.initial = "Cart"
    cfg.on_invalid = function(ctx, ev, st)
      table.insert(ctx.history, "invalid:" .. tostring(st) .. ":" .. tostring(ev))
      return false
    end
    cfg.on_transition = function(ctx, ev, from, to)
      table.insert(ctx.history, tostring(from) .. "->" .. tostring(to) .. " via " .. tostring(ev))
      return ctx
    end
    cfg.trace = function(msg)
      print(msg)
    end

    table.insert(cfg.transitions, Order.transition("Cart", "Checkout", "Checkout", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Checkout", "Pay", "Paid",
      function(ctx)
        return ctx.fraud_score < 0.7
      end,
      function(ctx)
        ctx.payment_captured = true
        table.insert(ctx.history, "payment_captured")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Checkout", "Cancel", "Canceled", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Paid", "Pack", "Packed",
      function(ctx)
        return ctx.inventory_reserved
      end,
      function(ctx)
        table.insert(ctx.history, "packed")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Packed", "Ship", "Shipped", nil,
      function(ctx)
        ctx.shipment_id = "SHIP-" .. tostring(#ctx.history + 1)
        table.insert(ctx.history, "shipped:" .. ctx.shipment_id)
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Shipped", "Deliver", "Delivered", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Delivered", "Return", "Refunded", nil,
      function(ctx)
        ctx.refund_issued = true
        table.insert(ctx.history, "return_refund")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Paid", "Cancel", "Refunded", nil,
      function(ctx)
        ctx.refund_issued = true
        table.insert(ctx.history, "cancel_refund")
        return ctx
      end, nil))
  end

  local raw_transitions = cfg.transitions
  assert(raw_transitions, "fsm: missing transitions for " .. "Order")
  local transitions = __fsm_build_transitions(raw_transitions)

  local name = (opts and opts.name) or cfg.name or "Order"

  -- `strict` is a boolean, so an `or` chain cannot express "caller passed
  -- false": an explicit false in `opts` must win over the configured value.
  local strict = cfg.strict
  if opts and opts.strict ~= nil then
    strict = opts.strict
  end
  if strict == nil then
    strict = false
  end

  local on_invalid = (opts and opts.on_invalid) or cfg.on_invalid or __fsm_noop_invalid
  local on_transition = (opts and opts.on_transition) or cfg.on_transition or __fsm_noop_transition
  local on_enter = (opts and opts.on_enter) or cfg.on_enter or __fsm_noop_action
  local on_exit = (opts and opts.on_exit) or cfg.on_exit or __fsm_noop_action
  local trace = (opts and opts.trace) or cfg.trace or __fsm_noop_trace

  -- Explicit state/event lists win; empty lists fall back to what the
  -- transitions mention.
  local states = (opts and opts.states) or (cfg.states or {})
  local events = (opts and opts.events) or (cfg.events or {})
  local auto_states, auto_events = __fsm_collect_states_events(transitions)
  if #states == 0 then
    states = auto_states
  end
  if #events == 0 then
    events = auto_events
  end

  local init = initial_state or cfg.initial
  assert(init, "fsm: missing initial state for " .. "Order")

  local self = {
    state = init,
    ctx = context,
    transitions = transitions,
    opts = {
      name = name,
      strict = strict,
      on_invalid = on_invalid,
      on_transition = on_transition,
      on_enter = on_enter,
      on_exit = on_exit,
      trace = trace,
      states = states,
      events = events,
    },
    states = states,
    events = events,
  }
  return setmetatable(self, { __index = Order })
end

-- Registers an additional transition on this instance.
function Order:add_transition(t)
  table.insert(self.transitions, __fsm_normalize_transition(t))
end

-- True when the machine is currently in state `s`.
function Order:state_is(s)
  return self.state == s
end

-- True when `event` has an eligible transition from the current state.
function Order:can(event)
  return __fsm_find(self, event) ~= nil
end

-- Runs the hooks in the order on_exit, action, on_enter, on_transition,
-- threading the context through each of them, then traces the move.
local function __fsm_apply_transition(self, event, t)
  local from = self.state
  local current_ctx = self.ctx
  local action = t.action
  current_ctx = self.opts.on_exit(current_ctx, event, from, t.to)
  current_ctx = (action and action(current_ctx, event, from, t.to)) or current_ctx
  self.state = t.to
  current_ctx = self.opts.on_enter(current_ctx, event, from, t.to)
  current_ctx = self.opts.on_transition(current_ctx, event, from, t.to)
  self.ctx = current_ctx
  self.opts.trace("fsm: " .. tostring(from) .. " --(" .. tostring(event) .. ")-> " .. tostring(self.state))
  return true
end

-- Reports an event with no eligible transition. The on_invalid hook always
-- runs first; in strict mode an error is raised afterwards.
local function __fsm_invalid(self, event)
  local ok = self.opts.on_invalid(self.ctx, event, self.state)
  assert(not self.opts.strict, "fsm: invalid transition (" .. tostring(self.state) .. ", " .. tostring(event) .. ") in " .. tostring(self.opts.name))
  return ok
end

-- Feeds `event` to the machine. Returns true when a transition was taken,
-- otherwise the result of the on_invalid hook (or raises in strict mode).
function Order:step(event)
  local t = __fsm_find(self, event)
  if t then
    return __fsm_apply_transition(self, event, t)
  end
  return __fsm_invalid(self, event)
end

-- Like step, but returns the state the machine ends up in.
function Order:dispatch(event)
  self:step(event)
  return self.state
end

-- Replaces the context and/or the state; nil arguments keep the current value.
function Order:reset(ctx, state)
  self.ctx = ctx or self.ctx
  self.state = state or self.state
end
-- Demo: drive one order through the full happy path, printing every step and
-- finally the history accumulated in the context.
local ctx = {
  inventory_reserved = true,
  payment_captured = false,
  fraud_score = 0.2,
  shipment_id = nil,
  refund_issued = false,
  history = {},
}
local order = Order.new(ctx)
print("initial", order.state)

-- Sends one event and reports the state before and after it.
local function send(ev)
  print("event", ev, "from", order.state)
  order:step(ev)
  print("state", order.state)
end

for _, ev in ipairs({ "Checkout", "Pay", "Pack", "Ship", "Deliver", "Return" }) do
  send(ev)
end

print("history", table.concat(order.ctx.history, " | "))
-- files/src/fsm.tl
-- Fully type-safe finite state machine implemented with macros
-------------------------------------------------------------------------------
local macro fsm!(spec: Statement, cfg: Expression)
local BI = BLOCK_INDEXES
-- Returns the single statement held by a "statements" block; any other block
-- is returned unchanged. Raises when the container does not hold exactly one
-- entry.
local function as_statement(b: Block): Block
    if b.kind ~= "statements" then
        return b
    end
    if #b ~= 1 then
        error("fsm!: expected a single statement")
    end
    return b[1]
end
-- Builds a "string" block whose token is `s` formatted with %q.
local function mkstr(s: string): Block
    local node = block("string")
    node.tk = ("%q"):format(s)
    return node
end
-- Builds an "identifier" block carrying `nm` as its token.
local function mkident(nm: string): Block
    local node = block("identifier")
    node.tk = nm
    return node
end
-- Builds a "variable" block carrying `nm` as its token.
local function mkvar(nm: string): Block
    local node = block("variable")
    node.tk = nm
    return node
end
-- Builds a "nominal_type" block for the type named `nm`; the name is stored
-- both as the token and as an identifier child.
local function mknominal(nm: string): Block
    local node = block("nominal_type")
    node.tk = nm
    node[BI.NOMINAL_TYPE.NAME] = mkident(nm)
    return node
end
-- Builds an "array_type" block whose element type is `elem`.
local function mkarray(elem: Block): Block
    local node = block("array_type")
    node[BI.ARRAY_TYPE.ELEMENT] = elem
    return node
end
-- Builds a "record_field" block declaring field `name` with type `typ`.
local function mkfield(name: string, typ: Block): Block
    local field = block("record_field")
    field[BI.RECORD_FIELD.NAME] = mkvar(name)
    field[BI.RECORD_FIELD.TYPE] = typ
    return field
end
-- Returns the first entry of a "statements" block, or `b` itself for any
-- other kind. Unlike as_statement, the entry count is not checked.
local function stmt1(b: Block): Block
    if b.kind ~= "statements" then
        return b
    end
    return b[1]
end
-- Strips the delimiters from a string token: matching single or double
-- quotes, or a level-0 long bracket ([[ ... ]]). Anything else is returned
-- unchanged. Escape sequences are not interpreted.
local function unquote_str(tk: string): string
    local first, last = tk:sub(1, 1), tk:sub(-1)
    local quoted = (first == "\"" or first == "'") and last == first
    if quoted then
        return tk:sub(2, -2)
    end
    local long_bracket = tk:sub(1, 2) == "[[" and tk:sub(-2) == "]]"
    if long_bracket then
        return tk:sub(3, -3)
    end
    return tk
end
-- Peels nested "paren" blocks and returns the innermost expression. A paren
-- block without an inner expression is returned as-is.
local function unwrap_paren(exp: Block): Block
    local node = exp
    while node and node.kind == "paren" do
        local inner = node[BI.PAREN.EXP]
        if not inner then
            break
        end
        node = inner
    end
    return node
end
-- True for "identifier" and "variable" blocks.
local function is_ident(exp: Block): boolean
    local kind = exp.kind
    return kind == "identifier" or kind == "variable"
end
-- Turns a bare (possibly parenthesized) identifier into a string block holding
-- its name. Every other expression is returned untouched, parentheses
-- included.
local function normalize_symbol(exp: Block): Block
    local inner = unwrap_paren(exp)
    if not is_ident(inner) then
        return exp
    end
    return mkstr(inner.tk)
end
-- Returns the (unparenthesized) expression when it is a "literal_table"
-- block, otherwise nil.
local function as_literal_table(exp: Block): Block | nil
    local inner = unwrap_paren(exp)
    if not inner or inner.kind ~= "literal_table" then
        return nil
    end
    return inner
end
-- For a function call whose callee is a plain identifier, returns that
-- identifier's name. Returns nil for everything else.
local function call_name(exp: Block): string | nil
    if exp.kind == "op_funcall" then
        local callee = exp[BI.OP.E1]
        if callee and is_ident(callee) then
            return callee.tk
        end
    end
    return nil
end
-- Returns the first argument of a function call expression, or nil when `exp`
-- is not a call or its arguments are not an expression list.
local function call_arg1(exp: Block): Block | nil
    if exp.kind == "op_funcall" then
        local args = exp[BI.OP.E2]
        if args and args.kind == "expression_list" then
            return args[BI.EXPRESSION_LIST.FIRST]
        end
    end
    return nil
end
-- If `exp` is a call `alias(x)` where `alias` is one of the names in
-- `aliases`, returns `x`. In every other case `exp` is returned unchanged.
local function unwrap_alias_call(exp: Block, aliases: {string: boolean}): Block
    local call = unwrap_paren(exp)
    if call.kind ~= "op_funcall" then
        return exp
    end
    local callee = call[BI.OP.E1]
    if not (callee and is_ident(callee) and aliases[callee.tk]) then
        return exp
    end
    local args = call[BI.OP.E2]
    if not (args and args.kind == "expression_list") then
        return exp
    end
    -- a call without arguments is left as it was
    return args[BI.EXPRESSION_LIST.FIRST] or exp
end
-- Splits an expression of the shape `base & guard | action` into its parts.
-- The right operand of an outer `|` is taken as the action; the right operand
-- of the `&` found next is taken as the guard. Alias wrappers such as
-- `when(...)` or `act(...)` are removed from either one.
-- Returns: base expression, guard or nil, action or nil.
local function split_guard_action(exp: Block): (Block, Block | nil, Block | nil)
    local action_aliases: {string: boolean} = { ["do"] = true, act = true, action = true }
    local guard_aliases: {string: boolean} = { when = true, guard = true, ["if"] = true }
    local target = unwrap_paren(exp)
    local guard: Block | nil = nil
    local action: Block | nil = nil
    if target.kind == "op_bor" then
        action = unwrap_alias_call(target[BI.OP.E2], action_aliases)
        target = unwrap_paren(target[BI.OP.E1])
    end
    if target.kind == "op_band" then
        guard = unwrap_alias_call(target[BI.OP.E2], guard_aliases)
        target = unwrap_paren(target[BI.OP.E1])
    end
    return target, guard, action
end
-- Builds a "literal_table" block with one integer-keyed item (1, 2, ...) per
-- entry of `items`; every value goes through `normalize_fn` first.
local function mktable_list(items: {Block}, normalize_fn: function(Block): Block): Block
    local tbl = block("literal_table")
    for idx, value in ipairs(items) do
        local entry = block("literal_table_item")
        local key = block("integer")
        key.tk = tostring(idx)
        entry[BI.LITERAL_TABLE_ITEM.KEY] = key
        entry[BI.LITERAL_TABLE_ITEM.VALUE] = normalize_fn(value)
        table.insert(tbl, entry)
    end
    return tbl
end
-- Rebuilds a literal table as a dense list: items without a value are dropped,
-- the remaining values are passed through `normalize_fn` and renumbered with
-- integer keys starting at 1.
local function normalize_list_table(tbl: Block, normalize_fn: function(Block): Block): Block
    local result = block("literal_table")
    local count = 0
    for _, item in ipairs(tbl) do
        local value = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
        if value then
            count = count + 1
            local entry = block("literal_table_item")
            local key = block("integer")
            key.tk = tostring(count)
            entry[BI.LITERAL_TABLE_ITEM.KEY] = key
            entry[BI.LITERAL_TABLE_ITEM.VALUE] = normalize_fn(value)
            table.insert(result, entry)
        end
    end
    return result
end
-- Validate the shape of `spec`: a single "local_type" statement whose value is
-- a record. Each step drills one level into the parsed block and aborts the
-- macro with an error when the expected node kind is missing.
local st = as_statement(spec)
if not st or st.kind ~= "local_type" then
    error("fsm!: expected a local record declaration")
end

local name_block = st[BI.LOCAL_TYPE.VAR]
if not name_block or name_block.kind ~= "identifier" then
    error("fsm!: record declaration must have a name")
end
-- the record name as an identifier block and as a quoted string block
local tname = name_block
local tname_str = mkstr(name_block.tk)

local record_newtype = st[BI.LOCAL_TYPE.VALUE]
if not record_newtype or record_newtype.kind ~= "newtype" then
    error("fsm!: expected a record type")
end

local type_decl = record_newtype[BI.NEWTYPE.TYPEDECL]
if not type_decl or type_decl.kind ~= "typedecl" then
    error("fsm!: expected a record type")
end

local record_def = type_decl[BI.TYPEDECL.TYPE]
if not record_def or record_def.kind ~= "record" then
    error("fsm!: expected a record type")
end

-- field list of the record; further declarations are appended to it later
local body = record_def[BI.RECORD.FIELDS]
if not body then
    error("fsm!: record missing body")
end
-- True when the record body contains a nested "local_type" entry named `nm`.
local function has_nested_type(nm: string): boolean
    for _, entry in ipairs(body) do
        if entry.kind == "local_type" then
            local declared = entry[BI.LOCAL_TYPE.VAR]
            if declared and declared.tk == nm then
                if declared.kind == "identifier" or declared.kind == "type_identifier" then
                    return true
                end
            end
        end
    end
    return false
end
-- The record must declare the nested types State, Event and Context; they are
-- checked in that order and the first missing one aborts the macro.
local required_types: {{string}} = {
    { "State", "fsm!: record must define a nested type 'State'" },
    { "Event", "fsm!: record must define a nested type 'Event'" },
    { "Context", "fsm!: record must define a nested type 'Context'" },
}
for _, required in ipairs(required_types) do
    if not has_nested_type(required[1]) then
        error(required[2])
    end
end
-- Parses the operator form `from / event >> to` (a `<<` is accepted in the
-- same position), optionally followed by `& guard` and `| action`.
-- Returns a spec table with from/event/to (symbols normalized) plus the
-- optional guard and action, or nil when the expression has another shape.
local function parse_transition_rune(exp: Block): {string: Block} | nil
    local arrow, guard, action = split_guard_action(exp)
    if arrow.kind ~= "op_shr" and arrow.kind ~= "op_shl" then
        return nil
    end
    local lhs = unwrap_paren(arrow[BI.OP.E1])
    local to = arrow[BI.OP.E2]
    if not (lhs and lhs.kind == "op_div") then
        return nil
    end
    local from = lhs[BI.OP.E1]
    local event = lhs[BI.OP.E2]
    if not (from and event and to) then
        return nil
    end
    return {
        from = normalize_symbol(from),
        event = normalize_symbol(event),
        to = normalize_symbol(to),
        guard = guard,
        action = action,
    }
end
-- Parses the method-chain form, e.g. `From:on(Ev):to(Target):when(g):act(a)`.
-- The chain is walked from the outermost call inwards, collecting the method
-- names and their first argument; whatever remains at the root becomes
-- `from`. Returns nil when there is no method call at all or no event was
-- given. `to` may still be missing in the returned spec.
local function parse_transition_chain(exp: Block): {string: Block} | nil
    local node = unwrap_paren(exp)
    local calls: {any} = {}
    while node and node.kind == "op_funcall" do
        local callee = node[BI.OP.E1]
        if not (callee and (callee.kind == "op_colon" or callee.kind == "op_dot")) then
            break
        end
        local method = callee[BI.OP.E2]
        if not (method and is_ident(method)) then
            break
        end
        local args = node[BI.OP.E2]
        local first_arg: Block | nil = nil
        if args and args.kind == "expression_list" then
            first_arg = args[BI.EXPRESSION_LIST.FIRST]
        end
        -- prepend, so the list ends up in source order
        table.insert(calls, 1, { name = method.tk, arg = first_arg })
        node = callee[BI.OP.E1]
    end
    if #calls == 0 then
        return nil
    end
    local result: {string: Block} = { from = normalize_symbol(node) }
    for _, call in ipairs(calls) do
        if (call.name == "on" or call.name == "via" or call.name == "event") and call.arg then
            result.event = normalize_symbol(call.arg)
        elseif (call.name == "to" or call.name == "into" or call.name == "then") and call.arg then
            result.to = normalize_symbol(call.arg)
        elseif (call.name == "guard" or call.name == "when" or call.name == "if") and call.arg then
            result.guard = call.arg
        elseif (call.name == "do" or call.name == "act" or call.name == "action") and call.arg then
            result.action = call.arg
        elseif (call.name == "name" or call.name == "named") and call.arg then
            result.name = normalize_symbol(call.arg)
        end
    end
    if not result.event then
        return nil
    end
    return result
end
-- Parses the table form `{ from = ..., event = ..., to = ..., guard = ...,
-- action = ..., name = ... }` (with the key aliases `on`, `when`, `do`).
-- Only string keys are considered. Returns nil unless from, event and to are
-- all present.
local function parse_transition_table(tbl: Block): {string: Block} | nil
    local guard_aliases: {string: boolean} = { when = true, guard = true, ["if"] = true }
    local action_aliases: {string: boolean} = { ["do"] = true, act = true, action = true }
    local result: {string: Block} = {}
    for _, entry in ipairs(tbl) do
        local key = entry[BI.LITERAL_TABLE_ITEM.KEY]
        local value = entry[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or entry[BI.LITERAL_TABLE_ITEM.VALUE]
        if key and key.kind == "string" and value then
            local field = unquote_str(key.tk)
            if field == "from" then
                result.from = normalize_symbol(value)
            elseif field == "event" or field == "on" then
                result.event = normalize_symbol(value)
            elseif field == "to" then
                result.to = normalize_symbol(value)
            elseif field == "guard" or field == "when" then
                result.guard = unwrap_alias_call(value, guard_aliases)
            elseif field == "action" or field == "do" then
                result.action = unwrap_alias_call(value, action_aliases)
            elseif field == "name" then
                result.name = normalize_symbol(value)
            end
        end
    end
    if not (result.from and result.event and result.to) then
        return nil
    end
    return result
end
-- Parses one transition written in any supported notation, tried in the order
-- operator form, method chain, table form. `value_opt`, when given, supplies
-- the target state if the expression itself has none. Alias wrappers around
-- guard and action are removed.
-- Returns the spec only when from, event and to are all known, otherwise nil.
local function parse_transition_expr(expr: Block, value_opt: Block | nil): {string: Block} | nil
    local inner = unwrap_paren(expr)
    local result = parse_transition_rune(inner) or parse_transition_chain(inner)
    if not result and inner.kind == "literal_table" then
        result = parse_transition_table(inner)
    end
    if not result then
        return nil
    end
    if value_opt and not result.to then
        result.to = normalize_symbol(value_opt)
    end
    if result.guard then
        result.guard = unwrap_alias_call(result.guard, { when = true, guard = true, ["if"] = true })
    end
    if result.action then
        result.action = unwrap_alias_call(result.action, { ["do"] = true, act = true, action = true })
    end
    if result.from and result.event and result.to then
        return result
    end
    return nil
end
-- Copies hook expressions from a `hooks`/`on` literal table into `cfg_info`.
-- Only string keys are considered; both the short and the `on_`-prefixed
-- spelling of a hook name are accepted, unknown keys are ignored.
local function parse_hooks_table(tbl: Block, cfg_info: {string: any})
    -- accepted key -> cfg_info field
    local hook_field: {string: string} = {
        invalid = "on_invalid", on_invalid = "on_invalid",
        transition = "on_transition", on_transition = "on_transition",
        enter = "on_enter", on_enter = "on_enter",
        exit = "on_exit", on_exit = "on_exit",
        trace = "trace",
    }
    for _, entry in ipairs(tbl) do
        local key = entry[BI.LITERAL_TABLE_ITEM.KEY]
        local value = entry[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or entry[BI.LITERAL_TABLE_ITEM.VALUE]
        if key and key.kind == "string" and value then
            local field = hook_field[unquote_str(key.tk)]
            if field then
                cfg_info[field] = value
            end
        end
    end
end
-- Walks a `flow`/`transitions` literal table and appends every transition it
-- can parse to `cfg_info.transitions`. Three entry shapes are handled:
--   * integer-keyed entries: the value is a complete transition expression;
--   * keyed entries whose value is a literal table: the key is the source
--     state and the nested table lists that state's transitions;
--   * any other keyed entry: the key is the transition expression and the
--     value supplies the target state when the key has none.
-- Entries that cannot be parsed are skipped silently.
-- NOTE(review): parse_transition_expr only returns specs that already carry
-- from, event and to, so the `spec.from or from` / `spec.event or event` /
-- `if not spec.to` fallbacks below look unreachable as written; confirm.
local function parse_transitions_table(tbl: Block, cfg_info: {string: any})
for _, item in ipairs(tbl) do
local key = item[BI.LITERAL_TABLE_ITEM.KEY]
local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
if key and key.kind == "integer" then
-- integer-keyed entry: a full transition expression
if val then
local spec = parse_transition_expr(val, nil)
if spec then
table.insert(cfg_info.transitions as {any}, spec)
end
end
else
if key and val then
local nested = as_literal_table(val)
if nested then
-- `State = { ... }`: the key names the source state for all nested entries
local from = normalize_symbol(key)
for _, sub in ipairs(nested) do
local sk = sub[BI.LITERAL_TABLE_ITEM.KEY]
local sv = sub[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or sub[BI.LITERAL_TABLE_ITEM.VALUE]
if sk and sk.kind == "integer" then
-- nested integer-keyed entry: transition expression, `from` defaults to the outer key
if sv then
local spec = parse_transition_expr(sv, nil)
if spec then
spec.from = spec.from or from
if spec.from and spec.event and spec.to then
table.insert(cfg_info.transitions as {any}, spec)
end
end
end
else
if sk and sv then
-- nested keyed entry: the key is the event, the value describes the target
local event = normalize_symbol(sk)
local spec = parse_transition_expr(sv, nil)
if not spec then
-- plain target, optionally decorated with `& guard` / `| action`
local base, guard, action = split_guard_action(sv)
spec = {
from = from,
event = event,
to = normalize_symbol(base),
guard = guard,
action = action,
}
else
spec.from = spec.from or from
spec.event = spec.event or event
if not spec.to then
local base, guard, action = split_guard_action(sv)
spec.to = normalize_symbol(base)
spec.guard = spec.guard or guard
spec.action = spec.action or action
end
end
if spec and spec.from and spec.event and spec.to then
table.insert(cfg_info.transitions as {any}, spec)
end
end
end
end
else
-- `[transition expr] = target`: the value supplies the target if the key has none
local spec = parse_transition_expr(key, val)
if spec then
table.insert(cfg_info.transitions as {any}, spec)
end
end
end
end
end
end
-- Translates the DSL table literal `tbl` into a "statements" block that
-- configures a local named `cfg`: one assignment per recognised setting and
-- one `table.insert(cfg.transitions, ...)` per parsed transition.
-- `default_name` is used as the machine name when the table sets none.
-- Unknown keys and unparseable entries are ignored silently.
local function parse_cfg_table(tbl: Block, default_name: Block | nil): Block
-- Settings collected while walking the table; values are AST blocks.
local cfg_info: {string: any} = {
transitions = {},
state_items = {},
event_items = {},
}
-- Handles a `key = value` entry (string key).
local function set_field(key: string, value: Block)
if key == "name" or key == "title" then
cfg_info.name = normalize_symbol(value)
elseif key == "strict" then
cfg_info.strict = value
elseif key == "initial" or key == "start" then
cfg_info.initial = normalize_symbol(value)
elseif key == "states" then
-- a literal list has its symbols normalized; any other expression passes through
local list = as_literal_table(value)
cfg_info.states = list and normalize_list_table(list, normalize_symbol) or value
elseif key == "events" then
local list = as_literal_table(value)
cfg_info.events = list and normalize_list_table(list, normalize_symbol) or value
elseif key == "flow" or key == "transitions" then
local flow_tbl = as_literal_table(value)
if flow_tbl then
parse_transitions_table(flow_tbl, cfg_info)
end
elseif key == "hooks" or key == "on" then
local hook_tbl = as_literal_table(value)
if hook_tbl then
parse_hooks_table(hook_tbl, cfg_info)
end
elseif key == "on_invalid" or key == "invalid" then
cfg_info.on_invalid = value
elseif key == "on_transition" or key == "transition" then
cfg_info.on_transition = value
elseif key == "on_enter" or key == "enter" then
cfg_info.on_enter = value
elseif key == "on_exit" or key == "exit" then
cfg_info.on_exit = value
elseif key == "trace" then
cfg_info.trace = value
end
end
-- Handles an integer-keyed entry: either a directive call such as
-- `initial(Cart)` or `strict`, or - when no directive matches - a transition
-- expression.
local function parse_directive(expr: Block)
local e = unwrap_paren(expr)
local handled = false
if e.kind == "op_funcall" then
local fname = call_name(e)
local arg1 = call_arg1(e)
if fname == "name" or fname == "title" then
if arg1 then cfg_info.name = normalize_symbol(arg1) end
handled = true
elseif fname == "strict" then
-- `strict()` without an argument means true
cfg_info.strict = arg1 or `true`
handled = true
elseif fname == "initial" or fname == "start" then
if arg1 then cfg_info.initial = normalize_symbol(arg1) end
handled = true
elseif fname == "states" then
local list = arg1 and as_literal_table(arg1)
cfg_info.states = list and normalize_list_table(list, normalize_symbol) or arg1
handled = true
elseif fname == "events" then
local list = arg1 and as_literal_table(arg1)
cfg_info.events = list and normalize_list_table(list, normalize_symbol) or arg1
handled = true
elseif fname == "state" then
-- single-item directives accumulate; they are only emitted when no `states` list was given
if arg1 then table.insert(cfg_info.state_items as {any}, normalize_symbol(arg1)) end
handled = true
elseif fname == "event" then
if arg1 then table.insert(cfg_info.event_items as {any}, normalize_symbol(arg1)) end
handled = true
elseif fname == "flow" or fname == "transitions" then
local flow_tbl = arg1 and as_literal_table(arg1)
if flow_tbl then
parse_transitions_table(flow_tbl, cfg_info)
end
handled = true
elseif fname == "hooks" or fname == "on" then
local hook_tbl = arg1 and as_literal_table(arg1)
if hook_tbl then
parse_hooks_table(hook_tbl, cfg_info)
end
handled = true
elseif fname == "trace" then
if arg1 then cfg_info.trace = arg1 end
handled = true
elseif fname == "on_invalid" then
if arg1 then cfg_info.on_invalid = arg1 end
handled = true
elseif fname == "on_transition" then
if arg1 then cfg_info.on_transition = arg1 end
handled = true
elseif fname == "on_enter" then
if arg1 then cfg_info.on_enter = arg1 end
handled = true
elseif fname == "on_exit" then
if arg1 then cfg_info.on_exit = arg1 end
handled = true
end
elseif is_ident(e) and e.tk == "strict" then
-- bare `strict` identifier
cfg_info.strict = `true`
handled = true
end
if not handled then
local spec = parse_transition_expr(e, nil)
if spec then
table.insert(cfg_info.transitions as {any}, spec)
end
end
end
-- Dispatch every table entry by the kind of its key.
for _, item in ipairs(tbl) do
local key = item[BI.LITERAL_TABLE_ITEM.KEY]
local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
if key and key.kind == "integer" then
if val then
parse_directive(val)
end
elseif key and key.kind == "string" and val then
set_field(unquote_str(key.tk), val)
else
-- any other key: the key itself is a transition expression, the value its target
if key and val then
local spec = parse_transition_expr(key, val)
if spec then
table.insert(cfg_info.transitions as {any}, spec)
end
end
end
end
if not cfg_info.name and default_name then
cfg_info.name = default_name
end
-- Emit one statement per collected setting.
local stmts = block("statements")
local function push_stmt(b: Block)
table.insert(stmts, stmt1(b))
end
if cfg_info.name then
local name = cfg_info.name
push_stmt(``` cfg.name = $name ```)
end
if cfg_info.strict then
local strict = cfg_info.strict
push_stmt(``` cfg.strict = $strict ```)
end
if cfg_info.initial then
local initial = cfg_info.initial
push_stmt(``` cfg.initial = $initial ```)
end
-- an explicit `states` list takes precedence over accumulated `state(...)` items
if cfg_info.states then
local states = cfg_info.states
push_stmt(``` cfg.states = $states ```)
elseif cfg_info.state_items and #cfg_info.state_items > 0 then
local list = mktable_list(cfg_info.state_items as {Block}, normalize_symbol)
push_stmt(``` cfg.states = $list ```)
end
if cfg_info.events then
local events = cfg_info.events
push_stmt(``` cfg.events = $events ```)
elseif cfg_info.event_items and #cfg_info.event_items > 0 then
local list = mktable_list(cfg_info.event_items as {Block}, normalize_symbol)
push_stmt(``` cfg.events = $list ```)
end
if cfg_info.on_invalid then
local on_invalid = cfg_info.on_invalid
push_stmt(``` cfg.on_invalid = $on_invalid ```)
end
if cfg_info.on_transition then
local on_transition = cfg_info.on_transition
push_stmt(``` cfg.on_transition = $on_transition ```)
end
if cfg_info.on_enter then
local on_enter = cfg_info.on_enter
push_stmt(``` cfg.on_enter = $on_enter ```)
end
if cfg_info.on_exit then
local on_exit = cfg_info.on_exit
push_stmt(``` cfg.on_exit = $on_exit ```)
end
if cfg_info.trace then
local trace = cfg_info.trace
push_stmt(``` cfg.trace = $trace ```)
end
-- Emit the transitions in the order they were collected; missing optional
-- parts are spliced as a literal nil.
for _, raw in ipairs(cfg_info.transitions) do
local spec = raw
local guard = (spec.guard and spec.guard) or `nil`
local action = (spec.action and spec.action) or `nil`
local name = (spec.name and spec.name) or `nil`
local from = spec.from
local event = spec.event
local to = spec.to
push_stmt(```
table.insert(cfg.transitions, $tname.transition($from, $event, $to, $guard, $action, $name))
```)
end
return stmts
end
-- Nested type declarations that are appended to the user's record body
-- further down. They refer to the State, Event and Context types that the
-- record is required to declare.

-- Guard: predicate over the context and the optional event/from/to.
local guard_decl = stmt1(```
local type Guard = function(ctx: Context, event?: Event, from?: State, to?: State): boolean
```)
-- Action: receives the context and returns a context.
local action_decl = stmt1(```
local type Action = function(ctx: Context, event?: Event, from?: State, to?: State): Context
```)
-- Transition: one edge of the machine; guard, action and name are optional.
local transition_decl = stmt1(```
local type Transition = record
from: State
event: Event
to: State
guard: Guard | nil
action: Action | nil
name: string | nil
end
```)
-- Options: machine configuration; every field may be nil.
local options_decl = stmt1(```
local type Options = record
name: string | nil
strict: boolean | nil
initial: State | nil
transitions: {any} | nil
on_invalid: function(ctx: Context, event: Event, from: State): boolean | nil
on_transition: function(ctx: Context, event: Event, from: State, to: State): Context | nil
on_enter: Action | nil
on_exit: Action | nil
trace: function(msg: string) | nil
states: {State} | nil
events: {Event} | nil
end
```)
-- DSL: signatures of the directive functions.
-- NOTE(review): presumably exists so directive calls in the cfg table can be
-- type-checked; its use is not visible in this part of the file.
local dsl_decl = stmt1(```
local type DSL = record
name: function(string)
strict: function(boolean)
initial: function(State)
state: function(State)
event: function(Event)
states: function({State})
events: function({Event})
on_invalid: function(function(Context, Event, State): boolean)
on_transition: function(function(Context, Event, State, State): Context)
on_enter: function(Action)
on_exit: function(Action)
trace: function(function(string))
transition: function(State, Event, State, ...: any)
end
```)
-- Extend the user's record: first the generated nested types, then the
-- instance fields every machine carries.
for _, decl in ipairs({ guard_decl, action_decl, transition_decl, options_decl, dsl_decl }) do
    table.insert(body, decl)
end

local state_type = mknominal("State")
table.insert(body, mkfield("state", state_type))
local context_type = mknominal("Context")
table.insert(body, mkfield("ctx", context_type))
local transitions_type = mkarray(mknominal("any"))
table.insert(body, mkfield("transitions", transitions_type))
local options_type = mknominal("Options")
table.insert(body, mkfield("opts", options_type))
local states_type = mkarray(mknominal("State"))
table.insert(body, mkfield("states", states_type))
local events_type = mkarray(mknominal("Event"))
table.insert(body, mkfield("events", events_type))
-- Resolve the `cfg` macro argument into `cfg_apply` (the statements that
-- configure the machine) and wrap them in a `do` block, `cfg_do`.
local cfg_apply: Block
local cfg_do: Block
do
local default_name: Block | nil = nil
local cfg_tbl: Block | nil = nil
-- Accepts either a bare table literal or a call `Name { ... }`. In the call
-- form the callee name is recorded in `default_name` as a side effect.
local function as_cfg_table(exp: Block): Block | nil
local e = unwrap_paren(exp)
local tbl = as_literal_table(e)
if tbl then
return tbl
end
if e.kind == "op_funcall" then
local arg1 = call_arg1(e)
local arg_tbl = arg1 and as_literal_table(arg1)
if arg_tbl then
local callee = e[BI.OP.E1]
if callee and is_ident(callee) then
default_name = mkstr(callee.tk)
end
return arg_tbl
end
end
return nil
end
cfg_tbl = as_cfg_table(cfg)
if cfg_tbl then
-- the record name is the fallback machine name
cfg_apply = parse_cfg_table(cfg_tbl, default_name or tname_str)
elseif cfg.kind == "function" then
-- NOTE(review): `legacy_apply` is not defined anywhere in the visible part of
-- this file; unless it is provided elsewhere, this branch yields no
-- configuration body. Confirm where it is supposed to come from.
cfg_apply = legacy_apply
else
error("fsm!: cfg must be a function or a DSL table literal")
end
cfg_do = block("do")
cfg_do[BI.DO.BODY] = cfg_apply
end
-- Searches `b` depth-first for the record function named `new` and inserts
-- `cfg_do` into its body: right after the first local declaration whose first
-- variable is `cfg`, or at the very top when there is no such declaration.
-- Returns true once the block has been inserted, false when no `new` function
-- with a statements body was found.
local function inject_cfg_do(b: Block): boolean
    if not b then
        return false
    end

    -- body of `new`, when `b` is that function
    local target: Block | nil = nil
    if b.kind == "record_function" then
        local fname = b[BI.RECORD_FUNCTION.NAME]
        if fname and fname.kind == "identifier" and fname.tk == "new" then
            local fbody = b[BI.RECORD_FUNCTION.BODY]
            if fbody and fbody.kind == "statements" then
                target = fbody
            end
        end
    end

    if target then
        for i = 1, #target do
            local s = target[i]
            if s and s.kind == "local_declaration" then
                local vlist = s[BI.LOCAL_DECLARATION.VARS]
                local first = vlist and vlist[BI.VARIABLE_LIST.FIRST]
                if first and (first.kind == "variable" or first.kind == "identifier") and first.tk == "cfg" then
                    table.insert(target, i + 1, cfg_do)
                    return true
                end
            end
        end
        table.insert(target, 1, cfg_do)
        return true
    end

    -- not the target itself: recurse into child blocks (numeric keys only)
    for k, child in pairs(b) do
        if type(k) == "number" and type(child) == "table" and child.kind then
            if inject_cfg_do(child) then
                return true
            end
        end
    end
    return false
end
local out = ```
$st
-- Default on_enter/on_exit hook: returns the context unchanged.
local __fsm_noop_action: $tname.Action = function(context: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context
return context
end
-- Default on_invalid hook: always returns false.
local function __fsm_noop_invalid(_ctx: $tname.Context, _event: $tname.Event, _from: $tname.State): boolean
return false
end
-- Default on_transition hook: returns the context unchanged.
local __fsm_noop_transition: function(ctx: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context = function(context: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context
return context
end
-- Default trace hook: discards the message.
local function __fsm_noop_trace(_msg: string)
end
-- Gives a transition the default name "from:event->to" when it has none, and
-- returns the same (mutated) transition.
local function __fsm_normalize_transition(t: $tname.Transition): $tname.Transition
t.name = t.name or (tostring(t.from) .. ":" .. tostring(t.event) .. "->" .. tostring(t.to))
return t
end
-- Returns a new list containing every entry of `list`, cast to Transition and
-- normalized.
local function __fsm_build_transitions(list: {any}): {$tname.Transition}
local out: {$tname.Transition} = {}
for _, t in ipairs(list) do
table.insert(out, __fsm_normalize_transition(t as $tname.Transition))
end
return out
end
-- Returns the distinct states and the distinct events mentioned by `list`,
-- each in order of first appearance.
local function __fsm_collect_states_events(list: {$tname.Transition}): ({$tname.State}, {$tname.Event})
local states: {$tname.State} = {}
local events: {$tname.Event} = {}
local seen_states: {any: boolean} = {}
local seen_events: {any: boolean} = {}
for _, t in ipairs(list) do
-- `table.insert(...) or true` appends on first sight and yields true, so each
-- line both records the value and marks it as seen
seen_states[t.from] = seen_states[t.from] or (table.insert(states, t.from) or true)
seen_states[t.to] = seen_states[t.to] or (table.insert(states, t.to) or true)
seen_events[t.event] = seen_events[t.event] or (table.insert(events, t.event) or true)
end
return states, events
end
-- Builds a Transition record from its parts; guard, action and name are
-- optional.
function $tname.transition(from: $tname.State, event: $tname.Event, to: $tname.State, guard?: $tname.Guard, action?: $tname.Action, name?: string): $tname.Transition
local t: $tname.Transition = {
from = from,
event = event,
to = to,
guard = guard,
action = action,
name = name,
}
return t
end
-- Returns the first transition registered for (self.state, event) whose guard
-- allows it, or nil when there is none.
-- A transition without a guard is always eligible. The earlier expression
-- `(guard and guard(...)) or true` was true even when the guard returned
-- false, so guards never blocked a transition.
local function __fsm_find(self: $tname, event: $tname.Event): $tname.Transition | nil
    local from = self.state
    local found: any = nil
    for _, raw in ipairs(self.transitions) do
        local t = raw as $tname.Transition
        local guard = t.guard
        -- `not found` short-circuits, so no further guard runs after a match
        local ok = (not found) and (t.from == from) and (t.event == event) and ((not guard) or (guard as $tname.Guard)(self.ctx, event, from, t.to))
        found = found or (ok and t or nil)
    end
    return found as $tname.Transition | nil
end
-- Creates a machine instance.
--   context:       user context handed to guards, actions and hooks.
--   initial_state: optional starting state; falls back to cfg.initial.
--   opts:          optional per-instance overrides. A field left nil falls back
--                  to cfg and then to a built-in default / no-op.
-- Raises (via assert) when no initial state is available.
-- Returns the new instance, whose method lookups fall through to the type table.
function $tname.new(context: $tname.Context, initial_state?: $tname.State, opts?: $tname.Options): $tname
    -- Type-level defaults.
    -- NOTE(review): the enclosing macro asserts that inject_cfg_do succeeds on
    -- this template; presumably it locates this cfg local to inject the
    -- user-supplied configuration right after it. Keep the name and position
    -- stable - confirm against inject_cfg_do.
    local cfg: $tname.Options = {
        name = $tname_str,
        strict = false,
        initial = nil,
        transitions = {},
        on_invalid = __fsm_noop_invalid,
        on_transition = __fsm_noop_transition,
        on_enter = __fsm_noop_action,
        on_exit = __fsm_noop_action,
        trace = __fsm_noop_trace,
        states = {},
        events = {},
    }
    local raw_transitions = cfg.transitions
    assert(raw_transitions, "fsm: missing transitions for " .. $tname_str)
    local transitions = __fsm_build_transitions(raw_transitions as {any})
    local name = (opts and opts.name) or cfg.name or $tname_str
    -- strict is a boolean, so a plain "a or b" chain would let a configured
    -- true win over an explicit false passed in opts. Fall back to cfg.strict
    -- only when opts did not specify the field at all.
    local strict_override = opts and opts.strict
    local strict = ((strict_override ~= nil) and strict_override) or ((strict_override == nil) and cfg.strict) or false
    local on_invalid = (opts and opts.on_invalid) or cfg.on_invalid or __fsm_noop_invalid
    local on_transition = (opts and opts.on_transition) or cfg.on_transition or __fsm_noop_transition
    local on_enter = (opts and opts.on_enter) or cfg.on_enter or __fsm_noop_action
    local on_exit = (opts and opts.on_exit) or cfg.on_exit or __fsm_noop_action
    local trace = (opts and opts.trace) or cfg.trace or __fsm_noop_trace
    local states: {$tname.State} = ((opts and opts.states) or (cfg.states or {})) as {$tname.State}
    local events: {$tname.Event} = ((opts and opts.events) or (cfg.events or {})) as {$tname.Event}
    -- When no explicit state/event list was given, use the ones that occur in
    -- the transitions.
    local auto_states, auto_events = __fsm_collect_states_events(transitions)
    states = ((#states == 0) and (auto_states as {$tname.State})) or (states as {$tname.State})
    events = ((#events == 0) and (auto_events as {$tname.Event})) or (events as {$tname.Event})
    local init = initial_state or cfg.initial
    assert(init, "fsm: missing initial state for " .. $tname_str)
    local init_state = init as $tname.State
    local self: $tname = {
        state = init_state,
        ctx = context,
        transitions = transitions,
        opts = {
            name = name,
            strict = strict,
            on_invalid = on_invalid,
            on_transition = on_transition,
            on_enter = on_enter,
            on_exit = on_exit,
            trace = trace,
            states = states,
            events = events,
        },
        states = states,
        events = events,
    }
    return setmetatable(self, { __index = $tname }) as $tname
end
-- Appends a transition to this instance's transition list after giving it a
-- name if it lacks one.
function $tname:add_transition(t: $tname.Transition)
    local normalized = __fsm_normalize_transition(t)
    table.insert(self.transitions, normalized)
end
-- Reports whether the machine is currently in state s.
function $tname:state_is(s: $tname.State): boolean
    local current = self.state
    return current == s
end
-- Reports whether a transition lookup for the event from the current state
-- yields a transition.
function $tname:can(event: $tname.Event): boolean
    local candidate = __fsm_find(self, event)
    return candidate ~= nil
end
-- Executes transition t for the given event and commits the result to self.
-- The context is threaded through, in order: the on_exit hook, the transition's
-- own action (if any), the on_enter hook, and the on_transition hook; every
-- call receives (ctx, event, from, to). The state is switched to t.to after the
-- action and before on_enter. Finally the trace hook receives a one-line
-- description of the move.
-- Always returns true.
local function __fsm_apply_transition(self: $tname, event: $tname.Event, t: $tname.Transition): boolean
local from = self.state
local current_ctx = self.ctx
local action = t.action
current_ctx = (self.opts.on_exit as $tname.Action)(current_ctx, event, from, t.to)
-- A missing action, or an action returning nil/false, leaves the context as is.
current_ctx = (action and (action as $tname.Action)(current_ctx, event, from, t.to)) or current_ctx
self.state = t.to
current_ctx = (self.opts.on_enter as $tname.Action)(current_ctx, event, from, t.to)
current_ctx = (self.opts.on_transition as function(ctx: $tname.Context, event?: $tname.Event, from?: $tname.State, to?: $tname.State): $tname.Context)(current_ctx, event, from, t.to)
-- The context is stored on self only after every hook has run.
self.ctx = current_ctx
(self.opts.trace as function(msg: string))("fsm: " .. tostring(from) .. " --(" .. tostring(event) .. ")-> " .. tostring(self.state))
return true
end
-- Handles an event that has no applicable transition.
-- The on_invalid hook is called first with (ctx, event, current state); then,
-- if the machine is strict, an error is raised through assert. Otherwise the
-- hook's result is returned.
local function __fsm_invalid(self: $tname, event: $tname.Event): boolean
    local handler = self.opts.on_invalid as function(ctx: $tname.Context, event: $tname.Event, from: $tname.State): boolean
    local handled = handler(self.ctx, event, self.state)
    local lenient = not self.opts.strict
    assert(lenient, "fsm: invalid transition (" .. tostring(self.state) .. ", " .. tostring(event) .. ") in " .. tostring(self.opts.name))
    return handled
end
-- Feeds one event to the machine. Applies the matching transition when there
-- is one; otherwise falls through to the invalid-event handling.
-- Returns the result of whichever path ran.
function $tname:step(event: $tname.Event): boolean
    local hit = __fsm_find(self, event)
    local applied = hit and __fsm_apply_transition(self, event, hit)
    return applied or __fsm_invalid(self, event)
end
-- Steps the machine with the event and returns the state it is in afterwards.
-- The boolean result of step is deliberately ignored.
function $tname:dispatch(event: $tname.Event): $tname.State
    local _handled = self:step(event)
    return self.state
end
-- Replaces the context and/or the state. An omitted (nil) argument keeps the
-- corresponding current value.
function $tname:reset(ctx?: $tname.Context, state?: $tname.State)
    local next_ctx = ctx or self.ctx
    local next_state = state or self.state
    self.ctx = next_ctx
    self.state = next_state
end
```
assert(inject_cfg_do(out), "fsm!: failed to inject cfg block")
return out
end
-- Demo usage: order fulfillment workflow
-- The fsm! macro receives a record declaration (states, events, context shape)
-- and a configuration table (strictness, initial state, flow, hooks).
fsm!(
local record Order
-- Every state an order can be in.
enum State
"Cart"
"Checkout"
"Paid"
"Packed"
"Shipped"
"Delivered"
"Canceled"
"Refunded"
end
-- Every event that can be sent to an order.
enum Event
"Checkout"
"Pay"
"Pack"
"Ship"
"Deliver"
"Cancel"
"Refund"
"Return"
end
-- Mutable data carried along with the order and passed to guards/actions/hooks.
record Context
inventory_reserved: boolean
payment_captured: boolean
fraud_score: number
shipment_id: string | nil
refund_issued: boolean
history: {string}
end
end,
{
strict = true,
initial = Cart,
-- states/events are inferred from transitions if omitted
-- NOTE(review): flow entries appear to read as From / Event >> To, with
-- "& when(fn)" attaching a guard and "| act(fn)" attaching an action; keyed
-- entries (State = { Event = Target }) look like an alternative spelling of
-- the same thing. Confirm against the fsm! macro implementation.
flow = {
Cart / Checkout >> Checkout,
-- Guard and action functions below take the order context; actions return it.
(Checkout / Pay >> Paid
& when(function(ctx: Order.Context): boolean return ctx.fraud_score < 0.7 end)
| act(function(ctx: Order.Context): Order.Context
ctx.payment_captured = true
table.insert(ctx.history, "payment_captured")
return ctx
end)),
Checkout / Cancel >> Canceled,
(Paid / Pack >> Packed
& when(function(ctx: Order.Context): boolean return ctx.inventory_reserved end)
| act(function(ctx: Order.Context): Order.Context
table.insert(ctx.history, "packed")
return ctx
end)),
Packed = {
Ship = (Shipped | act(function(ctx: Order.Context): Order.Context
-- The shipment id is derived from the number of history entries so far.
ctx.shipment_id = "SHIP-" .. tostring(#ctx.history + 1)
table.insert(ctx.history, "shipped:" .. ctx.shipment_id)
return ctx
end)),
},
Shipped / Deliver >> Delivered,
(Delivered / Return >> Refunded | act(function(ctx: Order.Context): Order.Context
ctx.refund_issued = true
table.insert(ctx.history, "return_refund")
return ctx
end)),
Paid = {
Cancel = (Refunded | act(function(ctx: Order.Context): Order.Context
ctx.refund_issued = true
table.insert(ctx.history, "cancel_refund")
return ctx
end)),
},
},
-- NOTE(review): the key here is "invalid" while the generated options use
-- "on_invalid"; presumably the macro maps one to the other - confirm.
hooks = {
-- Appends a "from->to via event" line to the history.
on_transition = function(ctx: Order.Context, ev: Order.Event, from: Order.State, to: Order.State): Order.Context
table.insert(ctx.history, tostring(from) .. "->" .. tostring(to) .. " via " .. tostring(ev))
return ctx
end,
-- Records the rejected (state, event) pair in the history and returns false.
invalid = function(ctx: Order.Context, ev: Order.Event, st: Order.State): boolean
table.insert(ctx.history, "invalid:" .. tostring(st) .. ":" .. tostring(ev))
return false
end,
-- Prints each trace message.
trace = function(msg: string)
print(msg)
end,
},
}
)
-- Starting context for the demo order: inventory is reserved, nothing has been
-- paid, shipped or refunded yet, and the fraud score is 0.2.
local initial_ctx: Order.Context = {
    inventory_reserved = true,
    payment_captured = false,
    fraud_score = 0.2,
    shipment_id = nil,
    refund_issued = false,
    history = {},
}
-- No explicit initial state or options are passed, so the configured ones apply.
local order = Order.new(initial_ctx)
print("initial", order.state)
-- Sends one event to the demo order, printing the state before and after.
local function send(ev: Order.Event)
    local before = order.state
    print("event", ev, "from", before)
    order:step(ev)
    local after = order.state
    print("state", after)
end
-- Macro that accepts any number of Statement arguments and returns the result
-- of block "statements" without referencing those arguments.
-- NOTE(review): presumably this expands to an empty statement block, i.e. the
-- arguments are discarded - confirm against the macro API.
-- It is not invoked in the visible part of this file.
local macro ignore!(...: Statement)
return block "statements"
end
-- Drive the order through the whole fulfillment path, one event at a time,
-- then print everything that was recorded in the history.
local happy_path: {Order.Event} = { "Checkout", "Pay", "Pack", "Ship", "Deliver", "Return" }
for _, ev in ipairs(happy_path) do
    send(ev)
end
print("history", table.concat(order.ctx.history, " | "))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment