(* fast-expon.ec *)

(* Definition and proof of correctness of fast exponentiation
   (repeated squaring). The procedure M.f below computes n ^ k, and each
   loop iteration roughly halves the remaining exponent. Lemma correct
   proves that M.f returns n ^ k for every n and every k >= 0. *)

(* Accept an smt call only when at least two of the listed provers
   (Alt-Ergo and Z3) succeed. *)
prover quorum=2 ["Alt-Ergo" "Z3"].

(* AllCore supplies int, the ^ operator and the exponentiation lemmas
   used below; IntDiv supplies %% and %/. *)
require import AllCore IntDiv.

(* Reference only (this whole block is a comment and is not checked):
   a library search query for (^) and the integer-exponentiation lemmas
   that the proofs below use.
search (^). (* integer exponentiation *)
lemma expr0 (x : int) : x ^ 0 = 1.
lemma expr1 (x : int) : x ^ 1 = x.
lemma exprS (x i : int) : 0 <= i => x ^ (i + 1) = x * x ^ i.
lemma expr2 (x : int) : x ^ 2 = x * x.
lemma exprM (x m n : int) : x ^ (m * n) = x ^ m ^ n.
*)

(* Parity predicates on int: even/odd, both of type int -> bool. *)
op even (k : int) = k %% 2 = 0. (* %% is mod *)
op odd (k : int) = ! even k.

(* An even integer is twice its quotient by 2. *)
lemma even_eq (k : int) : even k => k = 2 * (k %/ 2). (* %/ is integer division *)
proof. smt(). qed.

(* An odd integer is one more than twice the quotient of k - 1 by 2. *)
lemma odd_eq (k : int) : odd k => k = 2 * ((k - 1) %/ 2) + 1.
proof. smt(). qed.

(* Squaring step for an even exponent: n ^ k = (n * n) ^ (k %/ 2).
   The hypothesis 1 < k is the loop guard of M.f, which is where this
   lemma is applied. *)
lemma fastexp_even (n k : int) :
  1 < k => even k => n ^ k = (n * n) ^ (k %/ 2).
proof.
(* smt won't solve this! *)
move => gt1_k even_k.
(* {1} rewrites only the first occurrence of k (the exponent on the
   left-hand side) into 2 * (k %/ 2). The // closes the generated
   premise even k, which is hypothesis even_k. exprM then turns
   n ^ (2 * (k %/ 2)) into n ^ 2 ^ (k %/ 2), and expr2 turns n ^ 2 into
   n * n, which makes both sides syntactically equal. *)
by rewrite {1}(even_eq k) // exprM // expr2.
qed.

(* Multiply-and-square step for an odd exponent:
   n ^ k = (n * n) ^ ((k - 1) %/ 2) * n. The hypothesis 1 < k is the
   loop guard of M.f, which is where this lemma is applied. *)
lemma fastexp_odd (n k : int) :
  1 < k => odd k => n ^ k = (n * n) ^ ((k - 1) %/ 2) * n.
proof.
(* smt won't solve this! *)
move => gt1_k odd_k.
(* Rewrite the left exponent k as 2 * ((k - 1) %/ 2) + 1 and peel off
   one factor n with exprS. The 1:/# runs smt on the first resulting
   goal, presumably the side condition 0 <= 2 * ((k - 1) %/ 2) of exprS.
   Then square with exprM and expr2, and commute the product with
   mulzC. *)
by rewrite {1}(odd_eq k) // exprS 1:/# exprM expr2 mulzC.
qed.

(* Fast exponentiation by repeated squaring.
   Lemma correct shows that f returns n ^ k when 0 <= k. If k <= 0, the
   outer if is skipped and f returns the initial value 1 of r.
   Loop invariant (used in lemma correct): 0 < k' and n' ^ k' * r = n ^ k. *)
module M = {
  proc f(n : int, k : int) : int = {
    var n' : int;  (* current base; squared on every iteration *)
    var k' : int;  (* remaining exponent *)
    var r : int;   (* product of the factors split off in the odd case *)
    n' <- n;
    k' <- k;
    r <- 1;
    if (0 < k') {
      while (1 < k') {
        if (even k') {
          (* even exponent: square the base and halve the exponent *)
          n' <- n' * n';
          k' <- k' %/ 2;
        } else {
          (* odd exponent: move one factor n' into r, then square the
             base and halve the remaining even exponent k' - 1 *)
          r <- n' * r;
          n' <- n' * n';
          k' <- (k' - 1) %/ 2;
        }
      }
      (* on exit the invariant 0 < k' and the failed guard give k' = 1,
         so the remaining n' ^ k' is just n' *)
      r <- n' * r;
    }
    return r;
  }
}.

(* Functional correctness of M.f: for every k >= 0 it returns n ^ k. *)
lemma correct (n_ k_ : int) :
  hoare[M.f : n = n_ /\ k = k_ /\ 0 <= k ==> res = n_ ^ k_].
proof.
proc; simplify.
(* split off the three initial assignments, establishing this
   intermediate assertion *)
seq 3 : (n' = n_ /\ k' = k_ /\ r = 1 /\ 0 <= k').
wp; skip; trivial.
(* case split on the outer guard 0 < k' *)
if.
(* then-branch: wp consumes the final assignment r <- n' * r; the loop
   is then handled with the invariant below *)
wp.
while (0 < k' /\ n' ^ k' * r = n_ ^ k_).
(* loop body: case split on even k' *)
if.
(* even case: smt closes the positivity of the new k' and, using
   fastexp_even, preservation of the equation *)
wp; skip; progress.
smt().
smt(fastexp_even).
(* odd case: smt closes the positivity of the new k' and, using
   fastexp_odd, preservation of the equation *)
wp; skip; progress.
smt().
smt(fastexp_odd).
(* loop entry and exit: on exit 0 < k' and !(1 < k') force k' = 1,
   and expr1 rewrites n' ^ 1 to n' *)
skip; progress.
smt(expr1).
(* else-branch: 0 <= k' and !(0 < k') force k' = k_ = 0, so the result
   r = 1 equals n_ ^ 0 by expr0 *)
skip; progress.
smt(expr0).
qed.