Last active: February 3, 2026 17:41
Save Frityet/265b7ea18da8a93e3566be88b17a78c9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local table = _tl_compat and _tl_compat.table or table | |
-- Order: finite state machine emitted by the `fsm!` macro (compiled Teal).
-- An instance holds the current `state`, the user `ctx`, the normalised
-- `transitions` list and the resolved `opts` (name, strict, hooks, ...).
local Order = { Context = {}, Transition = {}, Options = {}, DSL = {} }

-- Default hooks used when neither the macro config nor `opts` supplies one.
local __fsm_noop_action = function(context, _event, _from, _to) return context end
local function __fsm_noop_invalid(_ctx, _event, _from) return false end
local __fsm_noop_transition = function(context, _event, _from, _to) return context end
local function __fsm_noop_trace(_msg) end

-- Give a transition the default name "<from>:<event>-><to>" if it has none.
-- Mutates and returns `t`.
local function __fsm_normalize_transition(t)
  t.name = t.name or (tostring(t.from) .. ":" .. tostring(t.event) .. "->" .. tostring(t.to))
  return t
end

-- Normalise every transition in `list` into a fresh array.
local function __fsm_build_transitions(list)
  local out = {}
  for _, t in ipairs(list) do
    table.insert(out, __fsm_normalize_transition(t))
  end
  return out
end

-- Collect the distinct states (both `from` and `to`) and events, in
-- first-seen order.
local function __fsm_collect_states_events(list)
  local states, events = {}, {}
  local seen_states, seen_events = {}, {}
  for _, t in ipairs(list) do
    if not seen_states[t.from] then
      seen_states[t.from] = true
      table.insert(states, t.from)
    end
    if not seen_states[t.to] then
      seen_states[t.to] = true
      table.insert(states, t.to)
    end
    if not seen_events[t.event] then
      seen_events[t.event] = true
      table.insert(events, t.event)
    end
  end
  return states, events
end

-- Build a transition record.
function Order.transition(from, event, to, guard, action, name)
  return { from = from, event = event, to = to, guard = guard, action = action, name = name }
end

-- Return the first transition leaving the current state on `event` whose
-- guard (if any) accepts, or nil when none is enabled.
local function __fsm_find(self, event)
  local from = self.state
  for _, t in ipairs(self.transitions) do
    if t.from == from and t.event == event then
      local guard = t.guard
      -- No guard means "always allowed"; otherwise the guard decides.
      -- (The former `(guard and guard(...)) or true` let every guard pass.)
      if guard == nil or guard(self.ctx, event, from, t.to) then
        return t
      end
    end
  end
  return nil
end

-- Create an Order machine.
--   context       user context table passed to guards, actions and hooks
--                 (the configured hooks append to `context.history`).
--   initial_state optional start state; defaults to the configured "Cart".
--   opts          optional overrides: name, strict, on_invalid,
--                 on_transition, on_enter, on_exit, trace, states, events.
-- Returns the new instance. Raises (via assert) when no transitions or no
-- initial state are available.
function Order.new(context, initial_state, opts)
  local cfg = {
    name = "Order", strict = false, initial = nil, transitions = {},
    on_invalid = __fsm_noop_invalid, on_transition = __fsm_noop_transition,
    on_enter = __fsm_noop_action, on_exit = __fsm_noop_action,
    trace = __fsm_noop_trace, states = {}, events = {},
  }
  -- Configuration spliced in by the macro from the user's DSL table.
  do
    cfg.name = "Order"
    cfg.strict = true
    cfg.initial = "Cart"
    cfg.on_invalid = function(ctx, ev, st)
      table.insert(ctx.history, "invalid:" .. tostring(st) .. ":" .. tostring(ev))
      return false
    end
    cfg.on_transition = function(ctx, ev, from, to)
      table.insert(ctx.history, tostring(from) .. "->" .. tostring(to) .. " via " .. tostring(ev))
      return ctx
    end
    cfg.trace = function(msg) print(msg) end
    table.insert(cfg.transitions, Order.transition("Cart", "Checkout", "Checkout", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Checkout", "Pay", "Paid",
      function(ctx) return ctx.fraud_score < 0.7 end,
      function(ctx)
        ctx.payment_captured = true
        table.insert(ctx.history, "payment_captured")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Checkout", "Cancel", "Canceled", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Paid", "Pack", "Packed",
      function(ctx) return ctx.inventory_reserved end,
      function(ctx)
        table.insert(ctx.history, "packed")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Packed", "Ship", "Shipped", nil,
      function(ctx)
        ctx.shipment_id = "SHIP-" .. tostring(#ctx.history + 1)
        table.insert(ctx.history, "shipped:" .. ctx.shipment_id)
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Shipped", "Deliver", "Delivered", nil, nil, nil))
    table.insert(cfg.transitions, Order.transition("Delivered", "Return", "Refunded", nil,
      function(ctx)
        ctx.refund_issued = true
        table.insert(ctx.history, "return_refund")
        return ctx
      end, nil))
    table.insert(cfg.transitions, Order.transition("Paid", "Cancel", "Refunded", nil,
      function(ctx)
        ctx.refund_issued = true
        table.insert(ctx.history, "cancel_refund")
        return ctx
      end, nil))
  end

  local raw_transitions = cfg.transitions
  assert(raw_transitions, "fsm: missing transitions for " .. "Order")
  local transitions = __fsm_build_transitions(raw_transitions)

  local name = (opts and opts.name) or cfg.name or "Order"
  -- `strict` is a boolean: `opts.strict or cfg.strict` would let an explicit
  -- `opts.strict = false` fall through to the configured `true`, so honour
  -- any non-nil override.
  local strict
  if opts and opts.strict ~= nil then
    strict = opts.strict
  else
    strict = cfg.strict or false
  end
  local on_invalid = (opts and opts.on_invalid) or cfg.on_invalid or __fsm_noop_invalid
  local on_transition = (opts and opts.on_transition) or cfg.on_transition or __fsm_noop_transition
  local on_enter = (opts and opts.on_enter) or cfg.on_enter or __fsm_noop_action
  local on_exit = (opts and opts.on_exit) or cfg.on_exit or __fsm_noop_action
  local trace = (opts and opts.trace) or cfg.trace or __fsm_noop_trace

  -- Explicit state/event lists win; otherwise derive them from the transitions.
  local states = (opts and opts.states) or cfg.states or {}
  local events = (opts and opts.events) or cfg.events or {}
  local auto_states, auto_events = __fsm_collect_states_events(transitions)
  if #states == 0 then states = auto_states end
  if #events == 0 then events = auto_events end

  local init = initial_state or cfg.initial
  assert(init, "fsm: missing initial state for " .. "Order")

  local self = {
    state = init,
    ctx = context,
    transitions = transitions,
    opts = {
      name = name, strict = strict,
      on_invalid = on_invalid, on_transition = on_transition,
      on_enter = on_enter, on_exit = on_exit, trace = trace,
      states = states, events = events,
    },
    states = states,
    events = events,
  }
  return setmetatable(self, { __index = Order })
end

-- Append a transition at runtime (named via __fsm_normalize_transition).
function Order:add_transition(t)
  table.insert(self.transitions, __fsm_normalize_transition(t))
end

function Order:state_is(s)
  return self.state == s
end

-- True when `event` has an enabled (guard-passing) transition from the current state.
function Order:can(event)
  return __fsm_find(self, event) ~= nil
end

-- Fire transition `t`. The context is threaded through on_exit, the
-- transition action, on_enter and on_transition in that order; the state
-- changes between the action and on_enter. Emits a trace line and returns true.
local function __fsm_apply_transition(self, event, t)
  local from = self.state
  local opts = self.opts
  local current_ctx = opts.on_exit(self.ctx, event, from, t.to)
  if t.action then
    -- An action returning nil/false keeps the previous context.
    current_ctx = t.action(current_ctx, event, from, t.to) or current_ctx
  end
  self.state = t.to
  current_ctx = opts.on_enter(current_ctx, event, from, t.to)
  current_ctx = opts.on_transition(current_ctx, event, from, t.to)
  self.ctx = current_ctx
  opts.trace("fsm: " .. tostring(from) .. " --(" .. tostring(event) .. ")-> " .. tostring(self.state))
  return true
end

-- Report an invalid event: call on_invalid, then raise if the machine is strict.
local function __fsm_invalid(self, event)
  local ok = self.opts.on_invalid(self.ctx, event, self.state)
  assert(not self.opts.strict, "fsm: invalid transition (" .. tostring(self.state) .. ", " .. tostring(event) .. ") in " .. tostring(self.opts.name))
  return ok
end

-- Apply `event`. Returns true when a transition fired; otherwise returns
-- on_invalid's result (or raises in strict mode).
function Order:step(event)
  local t = __fsm_find(self, event)
  if t then
    return __fsm_apply_transition(self, event, t)
  end
  return __fsm_invalid(self, event)
end

-- Apply `event` and return the resulting state.
function Order:dispatch(event)
  self:step(event)
  return self.state
end

-- Replace the context and/or state; nil arguments keep the current value.
function Order:reset(ctx, state)
  self.ctx = ctx or self.ctx
  self.state = state or self.state
end
-- Demo: drive one order through the happy path, logging each step, then
-- dump the history recorded by the machine's hooks and actions.
local ctx = {
  inventory_reserved = true,
  payment_captured = false,
  fraud_score = 0.2,
  shipment_id = nil, -- assigned by the Ship transition's action
  refund_issued = false,
  history = {},
}

local order = Order.new(ctx)
print("initial", order.state)

-- Log the event and current state, apply it, then log the new state.
local function send(ev)
  print("event", ev, "from", order.state)
  order:step(ev)
  print("state", order.state)
end

local happy_path = { "Checkout", "Pay", "Pack", "Ship", "Deliver", "Return" }
for _, ev in ipairs(happy_path) do
  send(ev)
end

print("history", table.concat(order.ctx.history, " | "))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| -- files/src/fsm.tl | |
| -- Fully type-safe finite state machine implemented with Teal macros | |
| ------------------------------------------------------------------------------- | |
| local macro fsm!(spec: Statement, cfg: Expression) | |
-- Shorthand for the macro API's child-index constants (b[BI.KIND.SLOT]).
local BI = BLOCK_INDEXES

-- Return the single statement carried by `b`. A "statements" list must hold
-- exactly one entry (otherwise this errors); any other block passes through.
local function as_statement(b: Block): Block
  if b.kind ~= "statements" then
    return b
  end
  if #b ~= 1 then
    error("fsm!: expected a single statement")
  end
  return b[1]
end
-- AST node constructors used to synthesise code inside the macro.

-- String-literal node whose token is `s` quoted with %q.
local function mkstr(s: string): Block
  local lit = block("string")
  lit.tk = string.format("%q", s)
  return lit
end

-- Identifier node named `nm`.
local function mkident(nm: string): Block
  local ident_node = block("identifier")
  ident_node.tk = nm
  return ident_node
end

-- Variable node named `nm`.
local function mkvar(nm: string): Block
  local var_node = block("variable")
  var_node.tk = nm
  return var_node
end

-- Nominal type reference to `nm`: token plus NAME child identifier.
local function mknominal(nm: string): Block
  local ref = block("nominal_type")
  ref.tk = nm
  ref[BI.NOMINAL_TYPE.NAME] = mkident(nm)
  return ref
end

-- Array type whose element type is `elem`.
local function mkarray(elem: Block): Block
  local arr = block("array_type")
  arr[BI.ARRAY_TYPE.ELEMENT] = elem
  return arr
end

-- Record field declaration `name: typ`.
local function mkfield(name: string, typ: Block): Block
  local field_node = block("record_field")
  field_node[BI.RECORD_FIELD.NAME] = mkvar(name)
  field_node[BI.RECORD_FIELD.TYPE] = typ
  return field_node
end
-- Reduce a "statements" list to its first entry (no arity check, unlike
-- as_statement). Any other block is returned unchanged.
local function stmt1(b: Block): Block
  if b.kind ~= "statements" then
    return b
  end
  return b[1]
end
-- Strip the delimiters from a Lua string-literal token and return its raw
-- contents. Recognises '...', "..." and long brackets of any level
-- ([[...]], [=[...]=], [==[...]==], ...). Escape sequences inside quoted
-- strings are NOT decoded. Any other token is returned unchanged.
local function unquote_str(tk: string): string
  local q = tk:sub(1, 1)
  -- Require two characters so a lone quote is not mistaken for "".
  if #tk >= 2 and (q == "\"" or q == "'") and tk:sub(-1) == q then
    return tk:sub(2, -2)
  end
  -- Long bracket: `[` + N `=` + `[` ... `]` + the same N `=` + `]`.
  local _, inner = tk:match("^%[(=*)%[(.*)%]%1%]$")
  if inner then
    -- Lua drops a newline that immediately follows the opening bracket.
    return (inner:gsub("^\r?\n", ""))
  end
  return tk
end
-- Peel off any number of redundant parentheses around an expression.
local function unwrap_paren(exp: Block): Block
  local node = exp
  while node and node.kind == "paren" do
    local inner = node[BI.PAREN.EXP]
    if not inner then
      break
    end
    node = inner
  end
  return node
end

-- True for bare names (identifier or variable nodes).
local function is_ident(exp: Block): boolean
  local k = exp.kind
  return k == "identifier" or k == "variable"
end

-- Turn a bare (possibly parenthesised) name into a string literal of that
-- name; any other expression is returned as given.
local function normalize_symbol(exp: Block): Block
  local inner = unwrap_paren(exp)
  if not is_ident(inner) then
    return exp
  end
  return mkstr(inner.tk)
end

-- Return the table-literal node behind `exp` (ignoring parens), or nil.
local function as_literal_table(exp: Block): Block | nil
  local inner = unwrap_paren(exp)
  if not inner or inner.kind ~= "literal_table" then
    return nil
  end
  return inner
end
-- Name of the called function when `exp` is a call on a bare name, else nil.
local function call_name(exp: Block): string | nil
  if exp.kind == "op_funcall" then
    local callee = exp[BI.OP.E1]
    if callee and is_ident(callee) then
      return callee.tk
    end
  end
  return nil
end

-- First argument of call expression `exp`, or nil.
local function call_arg1(exp: Block): Block | nil
  if exp.kind == "op_funcall" then
    local args = exp[BI.OP.E2]
    if args and args.kind == "expression_list" then
      return args[BI.EXPRESSION_LIST.FIRST]
    end
  end
  return nil
end

-- If `exp` is `alias(x)` for a name listed in `aliases` (e.g. `when(f)`),
-- return `x`; otherwise return `exp` unchanged.
local function unwrap_alias_call(exp: Block, aliases: {string: boolean}): Block
  local e = unwrap_paren(exp)
  local fname = call_name(e)
  if fname and aliases[fname] then
    local arg1 = call_arg1(e)
    if arg1 then
      return arg1
    end
  end
  return exp
end
-- Split `base & guard | action` into its parts. Each suffix is optional.
-- The guard/action may be wrapped in an alias call (`when(...)`, `do(...)`, ...),
-- which is unwrapped.
-- Returns base, guard-or-nil, action-or-nil.
local function split_guard_action(exp: Block): (Block, Block | nil, Block | nil)
  local action_aliases: {string: boolean} = { ["do"] = true, act = true, action = true }
  local guard_aliases: {string: boolean} = { when = true, guard = true, ["if"] = true }
  local guard: Block | nil = nil
  local action: Block | nil = nil
  local base = unwrap_paren(exp)
  -- `|` binds loosest, so the action (if any) is the outermost operand.
  if base.kind == "op_bor" then
    action = unwrap_alias_call(base[BI.OP.E2], action_aliases)
    base = unwrap_paren(base[BI.OP.E1])
  end
  if base.kind == "op_band" then
    guard = unwrap_alias_call(base[BI.OP.E2], guard_aliases)
    base = unwrap_paren(base[BI.OP.E1])
  end
  return base, guard, action
end
-- Append `[index] = normalize_fn(v)` as an explicitly keyed item of table
-- literal `out`.
local function push_indexed_item(out: Block, index: integer, v: Block, normalize_fn: function(Block): Block)
  local item = block("literal_table_item")
  local key = block("integer")
  key.tk = tostring(index)
  item[BI.LITERAL_TABLE_ITEM.KEY] = key
  item[BI.LITERAL_TABLE_ITEM.VALUE] = normalize_fn(v)
  table.insert(out, item)
end

-- Build `{ [1] = f(items[1]), [2] = f(items[2]), ... }` from a Lua array of blocks.
local function mktable_list(items: {Block}, normalize_fn: function(Block): Block): Block
  local out = block("literal_table")
  for i, v in ipairs(items) do
    push_indexed_item(out, i, v, normalize_fn)
  end
  return out
end

-- Rebuild table literal `tbl` as a densely numbered list of its item values,
-- each passed through `normalize_fn`. Items without a value are skipped.
local function normalize_list_table(tbl: Block, normalize_fn: function(Block): Block): Block
  local out = block("literal_table")
  local count = 0
  for _, entry in ipairs(tbl) do
    local v = entry[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or entry[BI.LITERAL_TABLE_ITEM.VALUE]
    if v then
      count = count + 1
      push_indexed_item(out, count, v, normalize_fn)
    end
  end
  return out
end
-- Validate the `spec` argument. It must be a single
-- `local type <Name> = record ... end` statement. Bind the pieces used later
-- (st, tname, tname_str, body); otherwise stop with an error.
local st = as_statement(spec)
if not st or st.kind ~= "local_type" then
  error("fsm!: expected a local record declaration")
end

local name_block = st[BI.LOCAL_TYPE.VAR]
if not name_block or name_block.kind ~= "identifier" then
  error("fsm!: record declaration must have a name")
end
-- The record name as an identifier (spliced into quotes) and as a string literal.
local tname = name_block
local tname_str = mkstr(name_block.tk)

-- Descend newtype -> typedecl -> record.
local record_newtype = st[BI.LOCAL_TYPE.VALUE]
if not record_newtype or record_newtype.kind ~= "newtype" then
  error("fsm!: expected a record type")
end
local type_decl = record_newtype[BI.NEWTYPE.TYPEDECL]
if not type_decl or type_decl.kind ~= "typedecl" then
  error("fsm!: expected a record type")
end
local record_def = type_decl[BI.TYPEDECL.TYPE]
if not record_def or record_def.kind ~= "record" then
  error("fsm!: expected a record type")
end

-- The record's member list.
local body = record_def[BI.RECORD.FIELDS]
if not body then
  error("fsm!: record missing body")
end
-- True when the record body declares a nested `local type <nm>`.
local function has_nested_type(nm: string): boolean
  for _, item in ipairs(body) do
    if item.kind == "local_type" then
      local v = item[BI.LOCAL_TYPE.VAR]
      local named = v and (v.kind == "identifier" or v.kind == "type_identifier")
      if named and v.tk == nm then
        return true
      end
    end
  end
  return false
end

-- The user's record must declare these nested types, checked in this order.
for _, required in ipairs({ "State", "Event", "Context" }) do
  if not has_nested_type(required) then
    error("fsm!: record must define a nested type '" .. required .. "'")
  end
end
-- Parse the operator form `From / Event >> To [& guard] [| action]`
-- (`<<` is accepted in place of `>>`). Returns a spec table with normalised
-- from/event/to plus optional guard/action, or nil if `exp` has another shape.
local function parse_transition_rune(exp: Block): {string: Block} | nil
  local e, guard, action = split_guard_action(exp)
  if e.kind ~= "op_shr" and e.kind ~= "op_shl" then
    return nil
  end
  local left = unwrap_paren(e[BI.OP.E1])
  if not left or left.kind ~= "op_div" then
    return nil
  end
  local from, event, to = left[BI.OP.E1], left[BI.OP.E2], e[BI.OP.E2]
  if not (from and event and to) then
    return nil
  end
  return {
    from = normalize_symbol(from),
    event = normalize_symbol(event),
    to = normalize_symbol(to),
    guard = guard,
    action = action,
  }
end
-- Parse the method-chain form, e.g. `From:on(Event):to(To):when(g):do(a)`.
-- Calls are applied innermost-first; later calls to the same field win.
-- Returns nil when there are no method calls or no event was given.
local function parse_transition_chain(exp: Block): {string: Block} | nil
  -- Method name -> spec field it populates.
  local field_of: {string: string} = {
    on = "event", via = "event", event = "event",
    to = "to", into = "to", ["then"] = "to",
    guard = "guard", when = "guard", ["if"] = "guard",
    ["do"] = "action", act = "action", action = "action",
    name = "name", named = "name",
  }
  -- Fields whose argument is stored as-is rather than symbol-normalised.
  local verbatim: {string: boolean} = { guard = true, action = true }

  -- Walk from the outermost call inwards, prepending so `calls` ends up innermost-first.
  local calls: {any} = {}
  local node = unwrap_paren(exp)
  while node and node.kind == "op_funcall" do
    local callee = node[BI.OP.E1]
    if not callee or (callee.kind ~= "op_colon" and callee.kind ~= "op_dot") then
      break
    end
    local method = callee[BI.OP.E2]
    if not method or not is_ident(method) then
      break
    end
    local args = node[BI.OP.E2]
    local first: Block | nil = nil
    if args and args.kind == "expression_list" then
      first = args[BI.EXPRESSION_LIST.FIRST]
    end
    table.insert(calls, 1, { name = method.tk, arg = first })
    node = callee[BI.OP.E1]
  end
  if #calls == 0 then
    return nil
  end

  -- Whatever the chain was rooted on is the source state.
  local spec: {string: Block} = { from = normalize_symbol(node) }
  for _, call in ipairs(calls) do
    local field = field_of[call.name]
    if field and call.arg then
      if verbatim[field] then
        spec[field] = call.arg
      else
        spec[field] = normalize_symbol(call.arg)
      end
    end
  end
  if not spec.event then
    return nil
  end
  return spec
end
-- Parse the table form of a transition, e.g.
--   { from = Cart, event = Checkout, to = Checkout, when = g, action = a }
-- Keys accept the same aliases as the method-chain form, so both spellings
-- are interchangeable:
--   event: event | on | via          to:     to | into | then
--   guard: guard | when | if         action: action | do | act
--   name:  name | named
-- from/event/to/name are symbol-normalised; guard/action have any alias-call
-- wrapper removed. Returns the spec only when from, event and to are all
-- present, else nil.
local function parse_transition_table(tbl: Block): {string: Block} | nil
  local field_of: {string: string} = {
    from = "from",
    event = "event", on = "event", via = "event",
    to = "to", into = "to", ["then"] = "to",
    guard = "guard", when = "guard", ["if"] = "guard",
    action = "action", ["do"] = "action", act = "action",
    name = "name", named = "name",
  }
  local guard_aliases: {string: boolean} = { when = true, guard = true, ["if"] = true }
  local action_aliases: {string: boolean} = { ["do"] = true, act = true, action = true }
  local spec: {string: Block} = {}
  for _, item in ipairs(tbl) do
    local key = item[BI.LITERAL_TABLE_ITEM.KEY]
    local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
    if key and key.kind == "string" and val then
      local field = field_of[unquote_str(key.tk)]
      if field == "guard" then
        spec.guard = unwrap_alias_call(val, guard_aliases)
      elseif field == "action" then
        spec.action = unwrap_alias_call(val, action_aliases)
      elseif field then
        spec[field] = normalize_symbol(val)
      end
    end
  end
  if spec.from and spec.event and spec.to then
    return spec
  end
  return nil
end
-- Parse any supported transition spelling: operator form, method chain, or
-- table literal, tried in that order. `value_opt`, if given, supplies `to`
-- when the expression lacks one. Guard/action alias wrappers are stripped.
-- Returns only complete specs (from, event and to set), else nil.
local function parse_transition_expr(expr: Block, value_opt: Block | nil): {string: Block} | nil
  local e = unwrap_paren(expr)
  local spec = parse_transition_rune(e) or parse_transition_chain(e)
  if not spec and e.kind == "literal_table" then
    spec = parse_transition_table(e)
  end
  if not spec then
    return nil
  end
  if value_opt and not spec.to then
    spec.to = normalize_symbol(value_opt)
  end
  if spec.guard then
    spec.guard = unwrap_alias_call(spec.guard, { when = true, guard = true, ["if"] = true })
  end
  if spec.action then
    spec.action = unwrap_alias_call(spec.action, { ["do"] = true, act = true, action = true })
  end
  if not (spec.from and spec.event and spec.to) then
    return nil
  end
  return spec
end
-- Copy hook functions from a `hooks`/`on` table literal into cfg_info.
-- Short and `on_`-prefixed key spellings are both accepted; other keys are ignored.
local function parse_hooks_table(tbl: Block, cfg_info: {string: any})
  local slot_for: {string: string} = {
    invalid = "on_invalid", on_invalid = "on_invalid",
    transition = "on_transition", on_transition = "on_transition",
    enter = "on_enter", on_enter = "on_enter",
    exit = "on_exit", on_exit = "on_exit",
    trace = "trace",
  }
  for _, item in ipairs(tbl) do
    local key = item[BI.LITERAL_TABLE_ITEM.KEY]
    local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
    if key and key.kind == "string" and val then
      local slot = slot_for[unquote_str(key.tk)]
      if slot then
        cfg_info[slot] = val
      end
    end
  end
end
-- Parse a `flow`/`transitions` table literal and append every complete spec
-- (from, event and to all set) to cfg_info.transitions.
-- Entry shapes handled:
--   * integer key (positional): any form parse_transition_expr accepts;
--   * other key whose value is a table literal (`A = { ... }`): A is the
--     default source state for the nested entries, which are either
--     positional transition expressions or `Ev = B [& guard] [| action]`;
--   * other key with a non-table value (`[expr] = B`): the key is parsed as
--     a transition expression and B supplies `to` if the key lacks one.
-- Entries that do not parse are skipped silently.
local function parse_transitions_table(tbl: Block, cfg_info: {string: any})
  for _, item in ipairs(tbl) do
    local key = item[BI.LITERAL_TABLE_ITEM.KEY]
    local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
    if key and key.kind == "integer" then
      -- Positional entry.
      if val then
        local spec = parse_transition_expr(val, nil)
        if spec then
          table.insert(cfg_info.transitions as {any}, spec)
        end
      end
    else
      if key and val then
        local nested = as_literal_table(val)
        if nested then
          -- State-keyed group: the key names the source state.
          local from = normalize_symbol(key)
          for _, sub in ipairs(nested) do
            local sk = sub[BI.LITERAL_TABLE_ITEM.KEY]
            local sv = sub[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or sub[BI.LITERAL_TABLE_ITEM.VALUE]
            if sk and sk.kind == "integer" then
              if sv then
                local spec = parse_transition_expr(sv, nil)
                if spec then
                  -- NOTE(review): parse_transition_expr only returns specs that
                  -- already have `from`, so this fallback appears never to apply.
                  spec.from = spec.from or from
                  if spec.from and spec.event and spec.to then
                    table.insert(cfg_info.transitions as {any}, spec)
                  end
                end
              end
            else
              if sk and sv then
                -- `Event = Target [& guard] [| action]` inside a state group.
                local event = normalize_symbol(sk)
                local spec = parse_transition_expr(sv, nil)
                if not spec then
                  local base, guard, action = split_guard_action(sv)
                  spec = {
                    from = from,
                    event = event,
                    to = normalize_symbol(base),
                    guard = guard,
                    action = action,
                  }
                else
                  -- NOTE(review): as above, a non-nil spec here is already
                  -- complete, so these defaults look unreachable — confirm.
                  spec.from = spec.from or from
                  spec.event = spec.event or event
                  if not spec.to then
                    local base, guard, action = split_guard_action(sv)
                    spec.to = normalize_symbol(base)
                    spec.guard = spec.guard or guard
                    spec.action = spec.action or action
                  end
                end
                if spec and spec.from and spec.event and spec.to then
                  table.insert(cfg_info.transitions as {any}, spec)
                end
              end
            end
          end
        else
          -- Expression key with the target state as the value.
          local spec = parse_transition_expr(key, val)
          if spec then
            table.insert(cfg_info.transitions as {any}, spec)
          end
        end
      end
    end
  end
end
-- Translate the DSL table literal `tbl` (the macro's cfg argument) into a
-- "statements" block. The block contains `cfg.<field> = ...` assignments, plus one
-- `table.insert(cfg.transitions, <Name>.transition(...))` per parsed transition.
-- The caller wraps the result in `do ... end` and splices it into the
-- generated constructor.
--   tbl          literal_table block to read.
--   default_name name literal used when the table does not set one.
-- Entry dispatch:
--   * integer keys   -> parse_directive (call-style directives or transitions);
--   * string keys    -> set_field (`name = ...`, `flow = {...}`, ...);
--   * any other key  -> transition expression whose value supplies `to`.
local function parse_cfg_table(tbl: Block, default_name: Block | nil): Block
  -- Collected configuration; transitions/state_items/event_items are lists.
  local cfg_info: {string: any} = {
    transitions = {},
    state_items = {},
    event_items = {},
  }
  -- Handle a `key = value` entry. Unknown keys are ignored.
  local function set_field(key: string, value: Block)
    if key == "name" or key == "title" then
      cfg_info.name = normalize_symbol(value)
    elseif key == "strict" then
      cfg_info.strict = value
    elseif key == "initial" or key == "start" then
      cfg_info.initial = normalize_symbol(value)
    elseif key == "states" then
      -- Literal lists get their names normalised; other expressions pass through.
      local list = as_literal_table(value)
      cfg_info.states = list and normalize_list_table(list, normalize_symbol) or value
    elseif key == "events" then
      local list = as_literal_table(value)
      cfg_info.events = list and normalize_list_table(list, normalize_symbol) or value
    elseif key == "flow" or key == "transitions" then
      local flow_tbl = as_literal_table(value)
      if flow_tbl then
        parse_transitions_table(flow_tbl, cfg_info)
      end
    elseif key == "hooks" or key == "on" then
      local hook_tbl = as_literal_table(value)
      if hook_tbl then
        parse_hooks_table(hook_tbl, cfg_info)
      end
    elseif key == "on_invalid" or key == "invalid" then
      cfg_info.on_invalid = value
    elseif key == "on_transition" or key == "transition" then
      cfg_info.on_transition = value
    elseif key == "on_enter" or key == "enter" then
      cfg_info.on_enter = value
    elseif key == "on_exit" or key == "exit" then
      cfg_info.on_exit = value
    elseif key == "trace" then
      cfg_info.trace = value
    end
  end
  -- Handle a positional entry. It is either a call-style directive such as
  -- `initial(Cart)`, the bare word `strict`, or a transition expression.
  -- NOTE(review): the generated DSL record declares `transition(...)`, but
  -- there is no `transition` branch here, and parse_transition_expr does not
  -- recognise a plain call. Such directives appear to be dropped silently,
  -- as is any other unrecognised entry.
  local function parse_directive(expr: Block)
    local e = unwrap_paren(expr)
    local handled = false
    if e.kind == "op_funcall" then
      local fname = call_name(e)
      local arg1 = call_arg1(e)
      if fname == "name" or fname == "title" then
        if arg1 then cfg_info.name = normalize_symbol(arg1) end
        handled = true
      elseif fname == "strict" then
        -- `strict()` with no argument means strict = true.
        cfg_info.strict = arg1 or `true`
        handled = true
      elseif fname == "initial" or fname == "start" then
        if arg1 then cfg_info.initial = normalize_symbol(arg1) end
        handled = true
      elseif fname == "states" then
        local list = arg1 and as_literal_table(arg1)
        cfg_info.states = list and normalize_list_table(list, normalize_symbol) or arg1
        handled = true
      elseif fname == "events" then
        local list = arg1 and as_literal_table(arg1)
        cfg_info.events = list and normalize_list_table(list, normalize_symbol) or arg1
        handled = true
      elseif fname == "state" then
        -- Single-item declarations accumulate; used only if `states` is unset.
        if arg1 then table.insert(cfg_info.state_items as {any}, normalize_symbol(arg1)) end
        handled = true
      elseif fname == "event" then
        if arg1 then table.insert(cfg_info.event_items as {any}, normalize_symbol(arg1)) end
        handled = true
      elseif fname == "flow" or fname == "transitions" then
        local flow_tbl = arg1 and as_literal_table(arg1)
        if flow_tbl then
          parse_transitions_table(flow_tbl, cfg_info)
        end
        handled = true
      elseif fname == "hooks" or fname == "on" then
        local hook_tbl = arg1 and as_literal_table(arg1)
        if hook_tbl then
          parse_hooks_table(hook_tbl, cfg_info)
        end
        handled = true
      elseif fname == "trace" then
        if arg1 then cfg_info.trace = arg1 end
        handled = true
      elseif fname == "on_invalid" then
        if arg1 then cfg_info.on_invalid = arg1 end
        handled = true
      elseif fname == "on_transition" then
        if arg1 then cfg_info.on_transition = arg1 end
        handled = true
      elseif fname == "on_enter" then
        if arg1 then cfg_info.on_enter = arg1 end
        handled = true
      elseif fname == "on_exit" then
        if arg1 then cfg_info.on_exit = arg1 end
        handled = true
      end
    elseif is_ident(e) and e.tk == "strict" then
      cfg_info.strict = `true`
      handled = true
    end
    if not handled then
      -- Fall back to treating the entry as a transition.
      local spec = parse_transition_expr(e, nil)
      if spec then
        table.insert(cfg_info.transitions as {any}, spec)
      end
    end
  end
  for _, item in ipairs(tbl) do
    local key = item[BI.LITERAL_TABLE_ITEM.KEY]
    local val = item[BI.LITERAL_TABLE_ITEM.TYPED_VALUE] or item[BI.LITERAL_TABLE_ITEM.VALUE]
    if key and key.kind == "integer" then
      if val then
        parse_directive(val)
      end
    elseif key and key.kind == "string" and val then
      set_field(unquote_str(key.tk), val)
    else
      if key and val then
        local spec = parse_transition_expr(key, val)
        if spec then
          table.insert(cfg_info.transitions as {any}, spec)
        end
      end
    end
  end
  if not cfg_info.name and default_name then
    cfg_info.name = default_name
  end
  -- Emission. The order of the pushes below is the order of the assignments
  -- in the generated constructor.
  local stmts = block("statements")
  local function push_stmt(b: Block)
    table.insert(stmts, stmt1(b))
  end
  if cfg_info.name then
    local name = cfg_info.name
    push_stmt(``` cfg.name = $name ```)
  end
  if cfg_info.strict then
    local strict = cfg_info.strict
    push_stmt(``` cfg.strict = $strict ```)
  end
  if cfg_info.initial then
    local initial = cfg_info.initial
    push_stmt(``` cfg.initial = $initial ```)
  end
  -- An explicit `states` list wins over accumulated `state(...)` items.
  if cfg_info.states then
    local states = cfg_info.states
    push_stmt(``` cfg.states = $states ```)
  elseif cfg_info.state_items and #cfg_info.state_items > 0 then
    local list = mktable_list(cfg_info.state_items as {Block}, normalize_symbol)
    push_stmt(``` cfg.states = $list ```)
  end
  if cfg_info.events then
    local events = cfg_info.events
    push_stmt(``` cfg.events = $events ```)
  elseif cfg_info.event_items and #cfg_info.event_items > 0 then
    local list = mktable_list(cfg_info.event_items as {Block}, normalize_symbol)
    push_stmt(``` cfg.events = $list ```)
  end
  if cfg_info.on_invalid then
    local on_invalid = cfg_info.on_invalid
    push_stmt(``` cfg.on_invalid = $on_invalid ```)
  end
  if cfg_info.on_transition then
    local on_transition = cfg_info.on_transition
    push_stmt(``` cfg.on_transition = $on_transition ```)
  end
  if cfg_info.on_enter then
    local on_enter = cfg_info.on_enter
    push_stmt(``` cfg.on_enter = $on_enter ```)
  end
  if cfg_info.on_exit then
    local on_exit = cfg_info.on_exit
    push_stmt(``` cfg.on_exit = $on_exit ```)
  end
  if cfg_info.trace then
    local trace = cfg_info.trace
    push_stmt(``` cfg.trace = $trace ```)
  end
  for _, raw in ipairs(cfg_info.transitions) do
    local spec = raw
    -- Missing optional parts are emitted as a literal `nil` argument.
    local guard = (spec.guard and spec.guard) or `nil`
    local action = (spec.action and spec.action) or `nil`
    local name = (spec.name and spec.name) or `nil`
    local from = spec.from
    local event = spec.event
    local to = spec.to
    push_stmt(```
      table.insert(cfg.transitions, $tname.transition($from, $event, $to, $guard, $action, $name))
    ```)
  end
  return stmts
end
-- Helper type declarations generated for the user's record. Each quote
-- holds exactly one statement, extracted with stmt1. They refer to the
-- user's nested State/Event/Context types, whose presence was checked above.

-- Guard: called with the context (and optionally event/from/to); returns boolean.
local guard_decl = stmt1(```
  local type Guard = function(ctx: Context, event?: Event, from?: State, to?: State): boolean
```)
-- Action: same arguments as Guard, returns a Context.
local action_decl = stmt1(```
  local type Action = function(ctx: Context, event?: Event, from?: State, to?: State): Context
```)
-- Transition: one edge of the machine, with optional guard, action and name.
local transition_decl = stmt1(```
  local type Transition = record
    from: State
    event: Event
    to: State
    guard: Guard | nil
    action: Action | nil
    name: string | nil
  end
```)
-- Options: every field optional; mirrors the configuration keys the DSL sets.
local options_decl = stmt1(```
  local type Options = record
    name: string | nil
    strict: boolean | nil
    initial: State | nil
    transitions: {any} | nil
    on_invalid: function(ctx: Context, event: Event, from: State): boolean | nil
    on_transition: function(ctx: Context, event: Event, from: State, to: State): Context | nil
    on_enter: Action | nil
    on_exit: Action | nil
    trace: function(msg: string) | nil
    states: {State} | nil
    events: {Event} | nil
  end
```)
-- DSL: typed signatures for the call-style configuration directives.
local dsl_decl = stmt1(```
  local type DSL = record
    name: function(string)
    strict: function(boolean)
    initial: function(State)
    state: function(State)
    event: function(Event)
    states: function({State})
    events: function({Event})
    on_invalid: function(function(Context, Event, State): boolean)
    on_transition: function(function(Context, Event, State, State): Context)
    on_enter: function(Action)
    on_exit: function(Action)
    trace: function(function(string))
    transition: function(State, Event, State, ...: any)
  end
```)
-- Append the helper type declarations to the user's record body, in order.
for _, decl in ipairs({ guard_decl, action_decl, transition_decl, options_decl, dsl_decl }) do
  table.insert(body, decl)
end

-- Then append the runtime instance fields: { field name, element/nominal type, "array"? }.
local instance_fields: {{string}} = {
  { "state", "State" },
  { "ctx", "Context" },
  { "transitions", "any", "array" },
  { "opts", "Options" },
  { "states", "State", "array" },
  { "events", "Event", "array" },
}
for _, f in ipairs(instance_fields) do
  local typ = mknominal(f[2])
  if f[3] == "array" then
    typ = mkarray(typ)
  end
  table.insert(body, mkfield(f[1], typ))
end
-- Turn the `cfg` macro argument into `cfg_do`, a `do ... end` block of
-- configuration statements that inject_cfg_do later splices into the
-- generated `new` constructor.
local cfg_apply: Block
local cfg_do: Block
do
  local default_name: Block | nil = nil
  local cfg_tbl: Block | nil = nil
  -- Accept a bare table literal `{ ... }`, or a call whose first argument is
  -- a table literal, `Name { ... }`. In the call form, a bare-name callee
  -- becomes the default machine name.
  local function as_cfg_table(exp: Block): Block | nil
    local e = unwrap_paren(exp)
    local tbl = as_literal_table(e)
    if tbl then
      return tbl
    end
    if e.kind == "op_funcall" then
      local arg1 = call_arg1(e)
      local arg_tbl = arg1 and as_literal_table(arg1)
      if arg_tbl then
        local callee = e[BI.OP.E1]
        if callee and is_ident(callee) then
          default_name = mkstr(callee.tk)
        end
        return arg_tbl
      end
    end
    return nil
  end
  cfg_tbl = as_cfg_table(cfg)
  if cfg_tbl then
    cfg_apply = parse_cfg_table(cfg_tbl, default_name or tname_str)
  elseif cfg.kind == "function" then
    -- This branch used to assign `legacy_apply`, which is not declared
    -- anywhere in this macro. That left cfg_apply nil and produced a `do` block
    -- with no body. Report the unsupported form explicitly instead.
    error("fsm!: function-style cfg is not supported; pass a DSL table literal instead")
  else
    error("fsm!: cfg must be a DSL table literal")
  end
  cfg_do = block("do")
  cfg_do[BI.DO.BODY] = cfg_apply
end
| -- Position in a constructor body where `cfg_do` belongs: just after the first | |
| -- `local cfg ...` declaration, or 1 (the very start) when there is none. | |
| local function cfg_insert_index(stmts: Block): integer | |
| for i = 1, #stmts do | |
| local s = stmts[i] | |
| if s and s.kind == "local_declaration" then | |
| local vars = s[BI.LOCAL_DECLARATION.VARS] | |
| local head = vars and vars[BI.VARIABLE_LIST.FIRST] | |
| if head and (head.kind == "variable" or head.kind == "identifier") and head.tk == "cfg" then | |
| return i + 1 | |
| end | |
| end | |
| end | |
| return 1 | |
| end | |
| -- Depth-first search of `b` for a record function named `new` whose body is a | |
| -- statement list; splice `cfg_do` into it and report true. False if none found. | |
| local function inject_cfg_do(b: Block): boolean | |
| if not b then | |
| return false | |
| end | |
| if b.kind == "record_function" then | |
| local fname = b[BI.RECORD_FUNCTION.NAME] | |
| if fname and fname.kind == "identifier" and fname.tk == "new" then | |
| local fbody = b[BI.RECORD_FUNCTION.BODY] | |
| if fbody and fbody.kind == "statements" then | |
| table.insert(fbody, cfg_insert_index(fbody), cfg_do) | |
| return true | |
| end | |
| end | |
| end | |
| for key, node in pairs(b) do | |
| if type(key) == "number" and type(node) == "table" and node.kind and inject_cfg_do(node) then | |
| return true | |
| end | |
| end | |
| return false | |
| end | |
| -- Code template for the generated machine. `$st` re-emits the record statement | |
| -- (presumably the macro's first argument, bound above this view). The no-op | |
| -- helpers below are the defaults `new` uses when neither cfg nor opts sets a hook. | |
| local out = ``` | |
| $st | |
| local __fsm_noop_action: $tname.Action = function(context: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context | |
| return context | |
| end | |
| local function __fsm_noop_invalid(_ctx: $tname.Context, _event: $tname.Event, _from: $tname.State): boolean | |
| return false | |
| end | |
| local __fsm_noop_transition: function(ctx: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context = function(context: $tname.Context, _event?: $tname.Event, _from?: $tname.State, _to?: $tname.State): $tname.Context | |
| return context | |
| end | |
| local function __fsm_noop_trace(_msg: string) | |
| end | |
| local function __fsm_normalize_transition(t: $tname.Transition): $tname.Transition | |
| -- Give unnamed transitions a readable "from:event->to" label; mutates and returns `t`. | |
| if not t.name then | |
| t.name = tostring(t.from) .. ":" .. tostring(t.event) .. "->" .. tostring(t.to) | |
| end | |
| return t | |
| end | |
| local function __fsm_build_transitions(list: {any}): {$tname.Transition} | |
| -- Normalize each raw entry (stopping at the first nil, as ipairs does) into a new array. | |
| local built: {$tname.Transition} = {} | |
| for _, entry in ipairs(list) do | |
| built[#built + 1] = __fsm_normalize_transition(entry as $tname.Transition) | |
| end | |
| return built | |
| end | |
| local function __fsm_collect_states_events(list: {$tname.Transition}): ({$tname.State}, {$tname.Event}) | |
| -- Distinct states and events referenced by `list`, in first-appearance order | |
| -- (for each transition: its `from`, then its `to`, then its event). | |
| local state_list: {$tname.State} = {} | |
| local event_list: {$tname.Event} = {} | |
| local have_state: {any: boolean} = {} | |
| local have_event: {any: boolean} = {} | |
| local function add_state(s: $tname.State) | |
| if not have_state[s] then | |
| have_state[s] = true | |
| table.insert(state_list, s) | |
| end | |
| end | |
| for _, tr in ipairs(list) do | |
| add_state(tr.from) | |
| add_state(tr.to) | |
| if not have_event[tr.event] then | |
| have_event[tr.event] = true | |
| table.insert(event_list, tr.event) | |
| end | |
| end | |
| return state_list, event_list | |
| end | |
| function $tname.transition(from: $tname.State, event: $tname.Event, to: $tname.State, guard?: $tname.Guard, action?: $tname.Action, name?: string): $tname.Transition | |
| -- Plain constructor. `name` may be nil; __fsm_normalize_transition derives one | |
| -- when the transition is registered through `new` or `add_transition`. | |
| return { | |
| from = from, | |
| event = event, | |
| to = to, | |
| guard = guard, | |
| action = action, | |
| name = name, | |
| } | |
| end | |
| local function __fsm_find(self: $tname, event: $tname.Event): $tname.Transition | nil | |
| -- First transition (in registration order) that leaves the current state on | |
| -- `event` and whose guard is absent or returns a truthy value; nil if none. | |
| local from = self.state | |
| for _, raw in ipairs(self.transitions) do | |
| local t = raw as $tname.Transition | |
| if t.from == from and t.event == event then | |
| local guard = t.guard | |
| -- Test the guard explicitly: `(guard and guard(...)) or true` is always | |
| -- true, which would let a rejecting guard through. | |
| if not guard or (guard as $tname.Guard)(self.ctx, event, from, t.to) then | |
| return t | |
| end | |
| end | |
| end | |
| return nil | |
| end | |
| function $tname.new(context: $tname.Context, initial_state?: $tname.State, opts?: $tname.Options): $tname | |
| -- Construct a machine. `cfg` starts from the defaults below; the macro splices | |
| -- the cfg block in directly after this declaration (see inject_cfg_do). | |
| -- name, strict, the hooks, states and events may be overridden per instance via `opts`. | |
| local cfg: $tname.Options = { | |
| name = $tname_str, | |
| strict = false, | |
| initial = nil, | |
| transitions = {}, | |
| on_invalid = __fsm_noop_invalid, | |
| on_transition = __fsm_noop_transition, | |
| on_enter = __fsm_noop_action, | |
| on_exit = __fsm_noop_action, | |
| trace = __fsm_noop_trace, | |
| states = {}, | |
| events = {}, | |
| } | |
| local raw_transitions = cfg.transitions | |
| assert(raw_transitions, "fsm: missing transitions for " .. $tname_str) | |
| local transitions = __fsm_build_transitions(raw_transitions as {any}) | |
| local name = (opts and opts.name) or cfg.name or $tname_str | |
| -- `strict` is a boolean: an `or` chain would treat an explicit | |
| -- `opts.strict = false` as absent and fall back to cfg.strict, so test for nil. | |
| local strict: boolean = cfg.strict or false | |
| if opts and opts.strict ~= nil then | |
| strict = opts.strict as boolean | |
| end | |
| local on_invalid = (opts and opts.on_invalid) or cfg.on_invalid or __fsm_noop_invalid | |
| local on_transition = (opts and opts.on_transition) or cfg.on_transition or __fsm_noop_transition | |
| local on_enter = (opts and opts.on_enter) or cfg.on_enter or __fsm_noop_action | |
| local on_exit = (opts and opts.on_exit) or cfg.on_exit or __fsm_noop_action | |
| local trace = (opts and opts.trace) or cfg.trace or __fsm_noop_trace | |
| local states: {$tname.State} = ((opts and opts.states) or (cfg.states or {})) as {$tname.State} | |
| local events: {$tname.Event} = ((opts and opts.events) or (cfg.events or {})) as {$tname.Event} | |
| -- Empty state/event lists are filled in from the transitions themselves. | |
| local auto_states, auto_events = __fsm_collect_states_events(transitions) | |
| states = ((#states == 0) and (auto_states as {$tname.State})) or (states as {$tname.State}) | |
| events = ((#events == 0) and (auto_events as {$tname.Event})) or (events as {$tname.Event}) | |
| local init = initial_state or cfg.initial | |
| assert(init, "fsm: missing initial state for " .. $tname_str) | |
| local init_state = init as $tname.State | |
| local self: $tname = { | |
| state = init_state, | |
| ctx = context, | |
| transitions = transitions, | |
| opts = { | |
| name = name, | |
| strict = strict, | |
| on_invalid = on_invalid, | |
| on_transition = on_transition, | |
| on_enter = on_enter, | |
| on_exit = on_exit, | |
| trace = trace, | |
| states = states, | |
| events = events, | |
| }, | |
| states = states, | |
| events = events, | |
| } | |
| return setmetatable(self, { __index = $tname }) as $tname | |
| end | |
| -- Register an extra transition at runtime (after those from cfg), named if needed. | |
| function $tname:add_transition(t: $tname.Transition) | |
| local normalized = __fsm_normalize_transition(t) | |
| self.transitions[#self.transitions + 1] = normalized | |
| end | |
| -- True when the machine is currently in state `s`. | |
| function $tname:state_is(s: $tname.State): boolean | |
| return s == self.state | |
| end | |
| -- True when __fsm_find locates a transition for `event` from the current state. | |
| function $tname:can(event: $tname.Event): boolean | |
| local hit = __fsm_find(self, event) | |
| return hit ~= nil | |
| end | |
| local function __fsm_apply_transition(self: $tname, event: $tname.Event, t: $tname.Transition): boolean | |
| -- Hook order: on_exit, the transition's own action, state change, on_enter, | |
| -- on_transition, then trace. Each hook takes and returns the context; an | |
| -- action that returns nil/false leaves the context unchanged. Always returns true. | |
| local from = self.state | |
| local current_ctx = self.ctx | |
| local action = t.action | |
| current_ctx = (self.opts.on_exit as $tname.Action)(current_ctx, event, from, t.to) | |
| current_ctx = (action and (action as $tname.Action)(current_ctx, event, from, t.to)) or current_ctx | |
| self.state = t.to | |
| current_ctx = (self.opts.on_enter as $tname.Action)(current_ctx, event, from, t.to) | |
| current_ctx = (self.opts.on_transition as function(ctx: $tname.Context, event?: $tname.Event, from?: $tname.State, to?: $tname.State): $tname.Context)(current_ctx, event, from, t.to) | |
| self.ctx = current_ctx | |
| -- Call through a local: a statement beginning with `(` right after an | |
| -- assignment can be parsed by Lua as a call continuing the previous line. | |
| local trace = self.opts.trace as function(msg: string) | |
| trace("fsm: " .. tostring(from) .. " --(" .. tostring(event) .. ")-> " .. tostring(self.state)) | |
| return true | |
| end | |
| -- Handle an event for which __fsm_find found no transition: notify on_invalid | |
| -- first (so the hook runs even in strict mode), then raise via assert when the | |
| -- machine is strict; otherwise return on_invalid's result. | |
| local function __fsm_invalid(self: $tname, event: $tname.Event): boolean | |
| local ok = (self.opts.on_invalid as function(ctx: $tname.Context, event: $tname.Event, from: $tname.State): boolean)(self.ctx, event, self.state) | |
| assert(not self.opts.strict, "fsm: invalid transition (" .. tostring(self.state) .. ", " .. tostring(event) .. ") in " .. tostring(self.opts.name)) | |
| return ok | |
| end | |
| -- Fire `event`: apply the matching transition, else take the invalid-event path. | |
| function $tname:step(event: $tname.Event): boolean | |
| local t = __fsm_find(self, event) | |
| if t then | |
| return __fsm_apply_transition(self, event, t) or __fsm_invalid(self, event) | |
| end | |
| return __fsm_invalid(self, event) | |
| end | |
| -- Like step, but reports the resulting state instead of the success flag. | |
| function $tname:dispatch(event: $tname.Event): $tname.State | |
| self:step(event) | |
| return self.state | |
| end | |
| -- Replace the context and/or state; an omitted argument keeps the current value. | |
| function $tname:reset(ctx?: $tname.Context, state?: $tname.State) | |
| if ctx then | |
| self.ctx = ctx | |
| end | |
| if state then | |
| self.state = state | |
| end | |
| end | |
| ``` | |
| -- Splice the cfg `do` block into the generated constructor. This fails if `out` | |
| -- has no record function named `new` with a statement-list body. | |
| assert(inject_cfg_do(out), "fsm!: failed to inject cfg block") | |
| return out | |
| end | |
| -- Demo usage: an order-fulfillment workflow. | |
| -- Argument 1 is the machine's record (State/Event enums plus the Context passed | |
| -- through every hook); argument 2 is the DSL cfg table. | |
| fsm!( | |
| local record Order | |
| enum State | |
| "Cart" | |
| "Checkout" | |
| "Paid" | |
| "Packed" | |
| "Shipped" | |
| "Delivered" | |
| "Canceled" | |
| "Refunded" | |
| end | |
| enum Event | |
| "Checkout" | |
| "Pay" | |
| "Pack" | |
| "Ship" | |
| "Deliver" | |
| "Cancel" | |
| "Refund" | |
| "Return" | |
| end | |
| record Context | |
| inventory_reserved: boolean | |
| payment_captured: boolean | |
| fraud_score: number | |
| shipment_id: string | nil | |
| refund_issued: boolean | |
| history: {string} | |
| end | |
| end, | |
| { | |
| strict = true, | |
| initial = Cart, | |
| -- states/events are inferred from transitions if omitted | |
| -- Transition forms (interpreted by the macro's cfg parser, defined above this view): | |
| -- `From / Event >> To`, optionally `& when(guard)` and/or `| act(action)`; | |
| -- `From = { Event = (To | act(action)) }` is presumably the nested equivalent. | |
| flow = { | |
| Cart / Checkout >> Checkout, | |
| (Checkout / Pay >> Paid | |
| & when(function(ctx: Order.Context): boolean return ctx.fraud_score < 0.7 end) | |
| | act(function(ctx: Order.Context): Order.Context | |
| ctx.payment_captured = true | |
| table.insert(ctx.history, "payment_captured") | |
| return ctx | |
| end)), | |
| Checkout / Cancel >> Canceled, | |
| (Paid / Pack >> Packed | |
| & when(function(ctx: Order.Context): boolean return ctx.inventory_reserved end) | |
| | act(function(ctx: Order.Context): Order.Context | |
| table.insert(ctx.history, "packed") | |
| return ctx | |
| end)), | |
| Packed = { | |
| Ship = (Shipped | act(function(ctx: Order.Context): Order.Context | |
| ctx.shipment_id = "SHIP-" .. tostring(#ctx.history + 1) | |
| table.insert(ctx.history, "shipped:" .. ctx.shipment_id) | |
| return ctx | |
| end)), | |
| }, | |
| Shipped / Deliver >> Delivered, | |
| (Delivered / Return >> Refunded | act(function(ctx: Order.Context): Order.Context | |
| ctx.refund_issued = true | |
| table.insert(ctx.history, "return_refund") | |
| return ctx | |
| end)), | |
| Paid = { | |
| Cancel = (Refunded | act(function(ctx: Order.Context): Order.Context | |
| ctx.refund_issued = true | |
| table.insert(ctx.history, "cancel_refund") | |
| return ctx | |
| end)), | |
| }, | |
| }, | |
| -- Machine-wide hooks; presumably mapped onto the Options on_transition / | |
| -- on_invalid / trace fields by the cfg parser -- confirm. | |
| hooks = { | |
| on_transition = function(ctx: Order.Context, ev: Order.Event, from: Order.State, to: Order.State): Order.Context | |
| table.insert(ctx.history, tostring(from) .. "->" .. tostring(to) .. " via " .. tostring(ev)) | |
| return ctx | |
| end, | |
| invalid = function(ctx: Order.Context, ev: Order.Event, st: Order.State): boolean | |
| table.insert(ctx.history, "invalid:" .. tostring(st) .. ":" .. tostring(ev)) | |
| return false | |
| end, | |
| trace = function(msg: string) | |
| print(msg) | |
| end, | |
| }, | |
| } | |
| ) | |
| -- Starting context: inventory reserved and a low fraud score, so both guards | |
| -- in the flow accept. | |
| local ctx: Order.Context = { | |
| history = {}, | |
| fraud_score = 0.2, | |
| inventory_reserved = true, | |
| payment_captured = false, | |
| refund_issued = false, | |
| shipment_id = nil, | |
| } | |
| -- No initial state is passed, so the machine starts in the cfg's `initial`. | |
| local order = Order.new(ctx) | |
| print("initial", order.state) | |
| -- Fire one event at the demo machine, printing the state before and after. | |
| local function send(ev: Order.Event) | |
| local before = order.state | |
| print("event", ev, "from", before) | |
| order:step(ev) | |
| print("state", order.state) | |
| end | |
| -- `ignore!(...)`: discards its statement arguments and returns a fresh | |
| -- `statements` block, presumably expanding to nothing. Not invoked in this chunk. | |
| local macro ignore!(...: Statement) | |
| return block "statements" | |
| end | |
| -- Walk the order through checkout, payment, packing, shipping and delivery, | |
| -- then return it (Delivered --Return--> Refunded), and dump the history. | |
| local script: {Order.Event} = { "Checkout", "Pay", "Pack", "Ship", "Deliver", "Return" } | |
| for _, ev in ipairs(script) do | |
| send(ev) | |
| end | |
| print("history", table.concat(order.ctx.history, " | ")) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment