From ee1e659b0b41a808039e4af0849e43e6474d41b9 Mon Sep 17 00:00:00 2001 From: Gustavo Delerue Date: Fri, 13 Mar 2026 17:09:34 +0000 Subject: [PATCH] PR: Cfold refactoring --- src/ecPV.ml | 2 + src/ecParser.mly | 5 +- src/ecParsetree.ml | 2 +- src/ecUtils.ml | 13 ++ src/ecUtils.mli | 4 + src/phl/ecPhlCodeTx.ml | 327 +++++++++++++++++++++++++++------------- src/phl/ecPhlCodeTx.mli | 4 +- src/phl/ecPhlLoopTx.ml | 2 +- tests/cfold.ec | 62 ++++++++ 9 files changed, 313 insertions(+), 108 deletions(-) diff --git a/src/ecPV.ml b/src/ecPV.ml index c1c297d754..6a570d3473 100644 --- a/src/ecPV.ml +++ b/src/ecPV.ml @@ -47,6 +47,8 @@ module PVMap = struct Mnpv.find_opt (pvm m.pvm_env k) m.pvm_map let raw m = m.pvm_map + + end (* -------------------------------------------------------------------- *) diff --git a/src/ecParser.mly b/src/ecParser.mly index 52e9f4b499..c6646eac5d 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -2991,7 +2991,10 @@ interleave_info: { Pinterleave info } | CFOLD s=side? c=codepos n=word? - { Pcfold (s, c, n) } + { Pcfold (s, c, n, false) } + +| CFOLD STAR s=side? c=codepos n=word? + { Pcfold (s, c, n, true) } | RND s=side? info=rnd_info c=prefix(COLON, semrndpos)? { Prnd (s, c, info) } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 8e12874e82..3bf1007966 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -757,7 +757,7 @@ type phltactic = | Pcond of pcond_info | Pmatch of matchmode | Pswap of ((oside * pswap_kind) located list) - | Pcfold of (oside * pcodepos * int option) + | Pcfold of (oside * pcodepos * int option * bool) (* side + 1st inst + n lines + eager? 
*) | Pinline of inline_info | Poutline of outline_info | Pinterleave of interleave_info located diff --git a/src/ecUtils.ml b/src/ecUtils.ml index 0cd31828dd..58ad7062d1 100644 --- a/src/ecUtils.ml +++ b/src/ecUtils.ml @@ -626,6 +626,19 @@ module List = struct end in fun state xs -> aux state [] xs + + (* FIXME: REMOVE *) + let fold_left_map_filter_while (f: 'a -> 'b -> ('a * ('c option)) interruptible) = + let rec aux (state: 'a) (acc: 'c list) (xs : 'b list) = + match xs with + | [] -> (state, List.rev acc, []) + | y :: ys -> begin + match f state y with + | `Continue (state, Some y) -> aux state (y :: acc) ys + | `Continue (state, None) -> aux state acc ys + | `Interrupt -> (state, List.rev acc, xs) + end + in fun state xs -> aux state [] xs end (* -------------------------------------------------------------------- *) diff --git a/src/ecUtils.mli b/src/ecUtils.mli index 3141d814b6..7fe2c8f3a9 100644 --- a/src/ecUtils.mli +++ b/src/ecUtils.mli @@ -309,6 +309,10 @@ module List : sig ('a -> 'b -> [`Interrupt | `Continue of 'a * 'c]) -> 'a -> 'b list -> 'a * 'c list * 'b list + val fold_left_map_filter_while : + ('a -> 'b -> [`Interrupt | `Continue of 'a * ('c option)]) + -> 'a -> 'b list -> 'a * 'c list * 'b list + (* ------------------------------------------------------------------ *) val ksort: ?stable:bool -> ?rev:bool diff --git a/src/phl/ecPhlCodeTx.ml b/src/phl/ecPhlCodeTx.ml index 2968953af1..88e43b022d 100644 --- a/src/phl/ecPhlCodeTx.ml +++ b/src/phl/ecPhlCodeTx.ml @@ -185,10 +185,70 @@ let t_set_match_r (side : oside) (cpos : Position.codepos) (id : symbol) pattern (t_zip (set_match_stmt id pattern)) tc (* -------------------------------------------------------------------- *) -let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) (zpr : Zpr.zipper) = +let split_assignment (lv: lvalue) (e: expr) : ((prog_var * ty) * expr) list = + match lv, e with + | LvVar lv, e -> [(lv, e)] + | LvTuple lvs, { e_node = Etuple es } -> + 
List.combine lvs es + | LvTuple lvs, e -> + List.mapi (fun i (pv, ty) -> + ((pv, ty), e_proj_simpl e i ty)) lvs + + +(* + Works at the block level. + Starts from the distinguished instruction given to it and does the following: + Keep two sets: + - propagate: set of variables to constant fold and propagate + - preserve : set of variables whose value we need to preserve + proceed through the instructions and do the following: + - if instruction is an assignment: + - if it is assigning to something in preserve: + STOP + - if it is assigning to something in propagate: + update value of propagated variable, update preserve + - preserve is now the set of variables read in the + assigning expression after propagation and simplification + - otherwise: + inline values of variables in propagate (and possibly simplify) + - if instruction is control flow + = this means (possibly not exhaustively) calls, whiles, ifs, matches + - if no variables in propagate or preserve are written in the block: + - propagate value to body and proceed + - otherwise: + STOP + - if instruction is random sampling and variable is in propagate or preserve: + STOP + - when stopping: + add assignments to the code for each of the values in propagate + setting them to the value they have in the substitution + + If eager: + - we do not keep preserve and instead merge preserve into propagate. + When encountering a variable we would need to preserve we add it + to the propagation set. + - we can always keep going except for control flow + = in the case of ifs we can also inline this (very eager mode?) + = in the case of whiles it is very non-trivial how to do this automatically + (maybe not possible to do it fully automatically?) 
+ - for calls we could either inline (but should be done separately by the user) + or we could use some contract that replaces a call by an assignment + (should also be a separate tactic) + - for random samplings we could continue but we can also stop + + - When hitting a stop condition we can do a partial stop as well: + = throw out the variables that we cannot propagate and continue + with the rest + + TODO: + Cfold n -> Cfold cpos_range + Give propagate or preserve set to cfold? +*) + +let cfold_stmt ?(simplify = true) ?(eager = true) (pf, hyps) (me : memenv) (olen : int option) (zpr : Zpr.zipper) = let env = LDecl.toenv hyps in - let simplify : expr -> expr = + let e_simplify : expr -> expr = if simplify then (fun e -> let e = ss_inv_of_expr (fst me) e in let e = map_ss_inv1 (EcReduction.simplify EcReduction.nodelta hyps) e in @@ -196,82 +256,133 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( e ) else identity in - let for_instruction ((subst as subst0) : (expr, unit) Mpv.t) (i : instr) = - let wr = EcPV.i_write env i in - let i = Mpv.isubst env subst i in - - let (subst, asgn) = - List.fold_left_map (fun subst (pv, e) -> - let exception Remove in - - try - if PV.mem_pv env pv wr then raise Remove; - let rd = EcPV.e_read env e in - if PV.mem_pv env pv rd then raise Remove; - subst, None - - with Remove -> - Mpv.remove env pv subst, Some ((pv, e.e_ty), e) - ) subst (EcPV.Mnpv.bindings (Mpv.pvs subst)) in - - let asgn = List.filter_map identity asgn in - - let mk_asgn (lve : ((prog_var * ty) * expr) list) = - let lvs, es = List.split lve in - lv_of_list lvs - |> Option.map (fun lv -> i_asgn (lv, e_tuple es)) - |> Option.to_list in - - let exception Interrupt in - - try - let subst, aout = - let exception Default in - - try - match i.i_node with - | Sasgn (lv, e) -> - (* We already removed the variables of `lv` & the rhs from the substitution *) - (* We are only interested in the variables of `lv` that are in `wr` *) - let es = 
- match simplify e, lv with - | { e_node = Etuple es }, LvTuple _ -> es - | _, LvTuple _ -> raise Default - | e, _ -> [e] in - - let lv = lv_to_ty_list lv in - - let tosubst, asgn2 = List.partition (fun ((pv, _), _) -> - Mpv.mem env pv subst0 - ) (List.combine lv es) in - - let subst = - List.fold_left - (fun subst ((pv, _), e) -> Mpv.add env pv e subst) - subst tosubst in - - let asgn = - List.filter - (fun ((pv, _), _) -> not (Mpv.mem env pv subst)) - asgn in - - (subst, mk_asgn asgn @ mk_asgn asgn2) - - | Srnd _ -> - (subst, mk_asgn asgn @ [i]) - - | _ -> raise Default - - with Default -> - if List.exists - (fun (pv, _) -> Mpv.mem env pv subst0) - (fst (PV.elements wr)) - then raise Interrupt; - (subst, mk_asgn asgn @ [i]) - - in `Continue (subst, aout) + let i_simplify : instr -> instr = + if simplify then (fun i -> i) (* FIXME: get this to do something *) + else identity + in - with Interrupt -> `Interrupt + (* + for_instruction does the following: + Check if you can propagate across the given instruction as per + description above + If yes, do it, return the updates instructions (possibly none) + and update subst and preserve + If STOP return Interrupt + + EAGER MODE: + if we would fail by preserve -> + add move preserve to propagate and push + if we fail by if -> + add if to subst + *) + let for_instruction (subst, preserve: (expr, unit) Mpv.t * (PV.t Mnpv.t)) (i : instr) = + let esubst subst e = + EcPV.Mpv.esubst env subst e |> e_simplify + in + let isubst subst i = + EcPV.Mpv.isubst env subst i |> i_simplify + in + let is_preserved preserve pv = + Mnpv.exists (fun _ preserve -> EcPV.PV.mem_pv env pv preserve) preserve + in + let is_propagated subst pv = + Mnpv.contains (Mpv.pvs subst) pv + in + let preserved_pvs preserve = + Mnpv.bindings preserve |> List.snd |> List.fold_left (PV.union) PV.empty + in + let propagated_pvs subst = + (Mpv.pvs subst) |> Mnpv.bindings |> List.fst + in + let update_preserved preserve subst pv e = + let rd = EcPV.e_read env e in 
+ let rd = List.fold_left (fun rd pv -> + EcPV.PV.remove env pv rd + ) rd (propagated_pvs subst) + in + PVMap.add pv rd preserve + in + let promote_preserved_to_propagated subst preserve pv (e:expr) = + let preserve = Mnpv.map (fun preserve -> + PV.remove env pv preserve + ) preserve + in + let subst = Mpv.add env pv e subst in + (subst, Mnpv.add pv (PV.remove env pv (EcPV.e_read env e)) preserve) (* the promoted value must also protect the variables it reads *) + in + let collect_assignments asgns : instr option = + match asgns with + | [] -> None + | (lv, e)::[] -> Some (i_asgn ((LvVar lv), e)) + | asgns -> + let lvs, es = List.split asgns in + let lv = LvTuple lvs in + let e = e_tuple es in + Some (i_asgn (lv, e)) + in + match i.i_node with + | Sasgn (lv, e) -> + let asgns = split_assignment lv e in + let exception Abort in let subst0 = subst in (* all right-hand sides of one assignment are evaluated in the entry state *) + begin try + let (subst, preserve), asgns = List.fold_left_map (fun (subst, preserve) ((pv, t), e) -> + if is_preserved preserve pv then + if eager + then + let e = esubst subst0 e in + promote_preserved_to_propagated subst preserve pv e, None + else raise Abort + else + if not (is_propagated subst pv) then + (subst, preserve), Some ((pv, t), esubst subst0 e) + else + let e = esubst subst0 e in + let rd = EcPV.e_read env e in + (* We can propagate even if the expression + depends on what is being assigned + FIXME: should we remove all variables + being propagated? 
+ *) + let rd = EcPV.PV.remove env pv rd in + (* FIXME: add case for eager *) + let preserve = Mnpv.add pv rd preserve in + let subst = Mpv.add env pv e subst in + (subst, preserve), None + ) (subst, preserve) asgns + in + let asgns = List.filter_map identity asgns in + `Continue ((subst, preserve), Option.to_list (collect_assignments asgns)) + with Abort -> `Interrupt + end + | Srnd (_, _) + | Scall (_, _, _) + | Swhile (_, _) + | Sif (_, _, _) + | Smatch (_, _) -> + let wr = EcPV.i_write env i in + let spvs = Mnpv.keys (Mpv.pvs subst) in + let ppvs = List.fst (fst (PV.elements (preserved_pvs preserve))) in (* the variables read by the propagated values, not the keys of the map *) + if + let check = List.for_all (fun pv -> + not @@ EcPV.PV.mem_pv env pv wr) in + check spvs && check ppvs + then + `Continue ((subst, preserve), [isubst subst i]) + else + `Interrupt + | Sraise _ -> `Interrupt + | Sabstract id -> + let aus = EcEnv.AbsStmt.byid id env in + begin match aus with + | { aus_calls = []; aus_reads; aus_writes } -> + if List.for_all (fun (pv, _) -> + not ((is_propagated subst pv) || (is_preserved preserve pv)) + ) (aus_reads @ aus_writes) then + `Continue ((subst, preserve), [i]) + else + `Interrupt + | _ -> `Interrupt + end in let body, epilog = @@ -283,55 +394,63 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) ( tc_error pf "expecting at least %d instructions" olen; List.takedrop (olen+1) zpr.z_tail in - let lv, subst, body, rem = + let lv, (subst, preserve), body, rem = match body with | { i_node = Sasgn (lv, e) } :: is -> - let es = - match simplify e, lv with - | { e_node = Etuple es }, LvTuple _ -> es - | _, LvTuple _ -> - tc_error pf - "the left-value is a tuple but the right-hand expression \ - is not a tuple expression"; - | e, _ -> [e] in - let lv = lv_to_ty_list lv in - + let asgns = split_assignment lv e in + let lv = List.fst asgns in + if not (List.for_all (is_loc -| fst) lv) then tc_error pf "left-values must be made of local variables only"; + (* Variables in the domain of substs + are variables to be propagated *) + let 
subst = List.fold_left (fun subst ((pv, _), e) -> Mpv.add env pv e subst) - Mpv.empty (List.combine lv es) in + Mpv.empty asgns in + + let preserve = + List.fold_left + (fun preserve ((pv, _), e) -> + Mnpv.add + pv + EcPV.(PV.remove env pv (e_read env e)) + preserve) + Mnpv.empty + asgns + in - let subst, is, rem = - List.fold_left_map_while for_instruction subst is in + let (subst, preserve), is, rem = + List.fold_left_map_while for_instruction (subst, preserve) is in - lv, subst, List.flatten is, rem + lv, (subst, preserve), List.flatten is, rem | _ -> tc_error pf "cannot find a left-value assignment at given position" in - let lv, es = - List.filter_map (fun ((pv, _) as pvty) -> - match Mpv.find env pv subst with - | e -> Some (pvty, e) - | exception Not_found -> None - ) lv |> List.split in + + + let asgns = Mnpv.bindings (Mpv.pvs subst) in + + let lv, es = List.map (fun (pv, e) -> + (pv, e_ty e), e) asgns |> List.split + in let asgn = lv_of_list lv |> Option.map (fun lv -> i_asgn (lv, e_tuple es)) |> Option.to_list in 
+ let zpr = { zpr with Zpr.z_tail = body @ asgn @ rem @ epilog } in (me, zpr, []) (* -------------------------------------------------------------------- *) -let t_cfold_r side cpos olen g = +let t_cfold_r side cpos olen eager g = let tr = fun side -> `Fold (side, cpos, olen) in - let cb = fun cenv _ me zpr -> cfold_stmt cenv me olen zpr in + let cb = fun cenv _ me zpr -> cfold_stmt ~eager cenv me olen zpr in t_code_transform side ~bdhoare:true cpos tr (t_zip cb) g (* -------------------------------------------------------------------- *) @@ -339,12 +458,12 @@ let t_kill = FApi.t_low3 "code-tx-kill" t_kill_r let t_alias = FApi.t_low3 "code-tx-alias" t_alias_r let t_set = FApi.t_low4 "code-tx-set" t_set_r let t_set_match = FApi.t_low4 "code-tx-set-match" t_set_match_r -let t_cfold = FApi.t_low3 "code-tx-cfold" t_cfold_r +let t_cfold = FApi.t_low4 "code-tx-cfold" t_cfold_r (* -------------------------------------------------------------------- *) -let process_cfold (side, cpos, olen) tc = +let process_cfold (side, cpos, olen, eager) tc = let cpos = EcLowPhlGoal.tc1_process_codepos tc (side, cpos) in - t_cfold side cpos olen tc + t_cfold side cpos olen eager tc let process_kill (side, cpos, len) tc = let cpos = EcLowPhlGoal.tc1_process_codepos tc (side, cpos) in diff --git a/src/phl/ecPhlCodeTx.mli b/src/phl/ecPhlCodeTx.mli index b1dab22744..6e79de8ec8 100644 --- a/src/phl/ecPhlCodeTx.mli +++ b/src/phl/ecPhlCodeTx.mli @@ -12,14 +12,14 @@ val t_kill : oside -> codepos -> int option -> backward val t_alias : oside -> codepos -> psymbol option -> backward val t_set : oside -> codepos -> bool * psymbol -> expr -> backward val t_set_match : oside -> codepos -> symbol -> unienv * mevmap * form -> backward -val t_cfold : oside -> codepos -> int option -> backward +val t_cfold : oside -> codepos -> int option -> bool -> backward (* -------------------------------------------------------------------- *) val process_kill : 
oside * pcodepos * int option -> backward val process_alias : oside * pcodepos * psymbol option -> backward val process_set : oside * pcodepos * bool * psymbol * pexpr -> backward val process_set_match : oside * pcodepos * psymbol * pformula -> backward -val process_cfold : oside * pcodepos * int option -> backward +val process_cfold : oside * pcodepos * int option * bool -> backward val process_case : oside * pcodepos -> backward (* -------------------------------------------------------------------- *) diff --git a/src/phl/ecPhlLoopTx.ml b/src/phl/ecPhlLoopTx.ml index 60aa5e71a6..8742401360 100644 --- a/src/phl/ecPhlLoopTx.ml +++ b/src/phl/ecPhlLoopTx.ml @@ -334,7 +334,7 @@ let process_unroll_for ~cfold side cpos tc = let cpos = EcMatching.Position.shift ~offset:(-1) cpos in let clen = blen * (List.length zs - 1) in - FApi.t_last (EcPhlCodeTx.t_cfold side cpos (Some clen)) tcenv + FApi.t_last (EcPhlCodeTx.t_cfold side cpos (Some clen) false) tcenv end else tcenv (* -------------------------------------------------------------------- *) diff --git a/tests/cfold.ec b/tests/cfold.ec index 3d6d435623..0bd44f036e 100644 --- a/tests/cfold.ec +++ b/tests/cfold.ec @@ -1,6 +1,68 @@ (* -------------------------------------------------------------------- *) require import AllCore Distr. +(* -------------------------------------------------------------------- *) +theory CfoldSelf. + module M = { + proc f(a : int, b : int) : int = { + var c : int; + var d : int; + + c <- c; + c <- c + 1; + c <- c + d; + d <- b + a; + c <- d; + if (a + b = c) { + c <- 0; + a <- c; + } else { + c <- 1; + b <- c; + } + return c; + } + }. + + lemma L : hoare[M.f : true ==> res = 0]. + proof. + proc. + cfold 1. + by auto => /> ?; apply addzC. + qed. +end CfoldSelf. + +(* -------------------------------------------------------------------- *) +theory CfoldStarSelf. 
+ module M = { + proc f(a : int, b : int) : int = { + var c : int; + var d : int; + + c <- c; + c <- c + 1; + c <- c + d; + d <- b + a; + c <- d; + if (a + b = c) { + c <- 0; + a <- c; + } else { + c <- 1; + b <- c; + } + return c; + } + }. + + lemma L : hoare[M.f : true ==> res = 0]. + proof. + proc. + cfold* 1. + by auto => /> ?; apply addzC. + qed. +end CfoldStarSelf. + (* -------------------------------------------------------------------- *) theory CfoldStopIf. module M = {