Skip to content

Commit 49e112e

Browse files
committed
Implement cbv tactic
1 parent b26cc0a commit 49e112e

File tree

12 files changed

+344
-44
lines changed

12 files changed

+344
-44
lines changed

examples/cbv_tactic.me

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(* cbv_tactic.me *)
2+
(* Using the cbv tactic. *)
3+
4+
Inductive nat : Type := | O : nat | S : forall (_: nat), nat.
5+
6+
Check nat.
7+
Check (S O).
8+
Check nat_ind.
9+
10+
Fixpoint add (n : nat) (m : nat) {struct n} : nat := match n with | O => m | S n' => S (add n' m) end.
11+
12+
Theorem test1 : eq nat (add (S O) (O)) (S O).

src/commandlanguage/command_exec.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,10 @@ static int _handle_show_command(MEngineRuntime *rt, ShowCmd *show_cmd) {
832832
// Show proof term
833833
Expression *pending_theorem = proof_state_pending_theorem(rt->proof_state);
834834
Expression *pending_theorem_body = get_expression_body(pending_theorem);
835-
MPRINT(rt->options->quiet, stdout, HEADER "Proof Term:" CRESET "\n%s\n",
836-
stringify_expression(pending_theorem_body));
835+
Expression *pending_theorem_type = get_expression_type(pending_theorem);
836+
MPRINT(rt->options->quiet, stdout, HEADER "Proof Term:" CRESET "\n%s : %s\n",
837+
stringify_expression(pending_theorem_body),
838+
stringify_expression(pending_theorem_type));
837839
return 0;
838840
}
839841
case SHOW_KW_GOAL: {

src/engine/tactics.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#include "src/engine/tactics.h"
22

33
#include <stdlib.h>
4+
#include <string.h>
45

56
#include "src/common/doubly_linked_list.h"
67
#include "src/engine/rewrite_internal.h"
8+
#include "src/engine/tactic_api.h"
79
#include "src/engine/unify.h"
810
#include "src/kernel/expression.h"
11+
#include "src/kernel/normalize.h"
912
#include "src/kernel/subst.h"
1013
#include "src/runtime/core.h"
1114

@@ -326,4 +329,43 @@ TacticResult *erewrite_tactic(Expression *goal, Expression *lemma) {
326329
TacticResult *tac_result = init_tactic_result(true, new_goals, NULL);
327330
free_rewrite_result(rwr);
328331
return tac_result;
332+
}
333+
334+
TacticResult *cbv_tactic(Expression *goal, char **rules, int rule_count) {
335+
ReductionFlags flags = 0;
336+
if (!rules || rule_count == 0) {
337+
flags = REDUCE_ALL;
338+
} else {
339+
for (int i = 0; i < rule_count; i++) {
340+
if (strcmp(rules[i], "beta")) {
341+
flags |= REDUCE_BETA;
342+
} else if (strcmp(rules[i], "delta")) {
343+
flags |= REDUCE_DELTA;
344+
} else if (strcmp(rules[i], "iota")) {
345+
flags |= REDUCE_IOTA;
346+
} else if (strcmp(rules[i], "fix")) {
347+
flags |= REDUCE_FIX;
348+
} else {
349+
// unknown rule for now
350+
return init_tactic_result(false, NULL, "Unknown converison rule in cbv.");
351+
}
352+
}
353+
}
354+
355+
char *old_goal_name = get_hole_name(goal);
356+
Expression *old_goal_ty = get_expression_type(goal);
357+
Context *ctx = get_expression_context(goal);
358+
359+
Expression *new_goal_ty = normalize_cbv(old_goal_ty, flags);
360+
Expression *new_goal = init_hole_expression(old_goal_name, new_goal_ty, ctx);
361+
362+
if (!can_fill(goal, new_goal)) {
363+
return init_tactic_result(false, NULL, "Failed to fill the goal conversion");
364+
}
365+
366+
fill_hole(goal, new_goal);
367+
368+
DoublyLinkedList *new_goals = dll_create();
369+
dll_insert_at_tail(new_goals, dll_new_node(new_goal));
370+
return init_tactic_result(true, new_goals, NULL);
329371
}

src/engine/tactics.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,16 @@ TacticResult *rewrite_tactic(Expression *goal, Expression *lemma);
4040
// Same as rewrite_tactic, but allows unresolved binders to be added as goals to the proof state.
4141
TacticResult *erewrite_tactic(Expression *goal, Expression *lemma);
4242

43+
// Given a goal, and optionally a list of conversion rules, attempt to normalize the goal by
44+
// applying the set of conversion rules given in `rules`.
45+
//
46+
// Supported rules:
47+
// - beta: (fun x => t) u ~> t[x := u]
48+
// - delta: unfold definitions
49+
// - iota: match on constructors
50+
// - fix: fixpoint to lambdas
51+
//
52+
// If `rules` is NULL, then all rules are applied
53+
TacticResult *cbv_tactic(Expression *goal, char **rules, int rule_count);
54+
4355
#endif // NEW_TACTICS

src/kernel/definitional_equal.c

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "src/kernel/definitional_equal.h"
2+
3+
#include "src/common/linear_map.h"
4+
#include "src/kernel/normalize.h"
5+
6+
static bool _defeq(Expression *a, Expression *b, LinearMap *mapping) {
7+
// Reduce both sides to WHNF before inspecting structure
8+
a = normalize_whnf(a);
9+
b = normalize_whnf(b);
10+
11+
if (a == b) {
12+
return true;
13+
}
14+
15+
if (a->tag != b->tag) {
16+
return false;
17+
}
18+
19+
switch (a->tag) {
20+
case TYPE_EXPRESSION:
21+
case PROP_EXPRESSION:
22+
return true;
23+
24+
case VAR_EXPRESSION:
25+
case HOLE_EXPRESSION:
26+
return (a == b) || (linear_map_get(mapping, a) == b);
27+
28+
case APP_EXPRESSION:
29+
return _defeq(a->as.app.func, b->as.app.func, mapping) &&
30+
_defeq(a->as.app.arg, b->as.app.arg, mapping);
31+
32+
case FORALL_EXPRESSION:
33+
linear_map_set(mapping, a->as.forall.bound_variable, b->as.forall.bound_variable);
34+
return _defeq(a->as.forall.body, b->as.forall.body, mapping);
35+
36+
case LAMBDA_EXPRESSION:
37+
linear_map_set(mapping, a->as.lambda.bound_variable, b->as.lambda.bound_variable);
38+
return _defeq(a->as.lambda.body, b->as.lambda.body, mapping);
39+
40+
case MATCH_EXPRESSION: {
41+
if (!_defeq(a->as.match.scrutinee, b->as.match.scrutinee, mapping)) {
42+
return false;
43+
}
44+
45+
if (a->as.match.branch_count != b->as.match.branch_count) {
46+
return false;
47+
}
48+
49+
for (int i = 0; i < a->as.match.branch_count; i++) {
50+
MatchBranch *ba = a->as.match.branches[i];
51+
MatchBranch *bb = b->as.match.branches[i];
52+
53+
if (!_defeq(ba->constructor, bb->constructor, mapping)) {
54+
return false;
55+
}
56+
57+
if (ba->pattern_var_count != bb->pattern_var_count) {
58+
return false;
59+
}
60+
61+
for (int j = 0; j < ba->pattern_var_count; j++) {
62+
linear_map_set(mapping, ba->pattern_variables[j], bb->pattern_variables[j]);
63+
}
64+
65+
if (!_defeq(ba->body, bb->body, mapping)) {
66+
return false;
67+
}
68+
}
69+
return true;
70+
}
71+
72+
case FIX_EXPRESSION:
73+
linear_map_set(mapping, a->as.fix.recursive_var, b->as.fix.recursive_var);
74+
75+
if (a->as.fix.arg_count != b->as.fix.arg_count) {
76+
return false;
77+
}
78+
79+
if (a->as.fix.decreasing_arg_index != b->as.fix.decreasing_arg_index) {
80+
return false;
81+
}
82+
83+
for (int i = 0; i < a->as.fix.arg_count; i++) {
84+
linear_map_set(mapping, a->as.fix.args[i], b->as.fix.args[i]);
85+
}
86+
87+
return _defeq(a->as.fix.body, b->as.fix.body, mapping);
88+
}
89+
90+
return false;
91+
}
92+
93+
bool definitional_equal(Expression *a, Expression *b) {
94+
if (a == b) {
95+
return true;
96+
}
97+
98+
LinearMap *mapping = linear_map_new();
99+
bool result = _defeq(a, b, mapping);
100+
linear_map_clear_free(mapping);
101+
return result;
102+
}

src/kernel/definitional_equal.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef DEFINITIONAL_EQUAL_H
2+
#define DEFINITIONAL_EQUAL_H
3+
4+
#include "src/kernel/expression.h"
5+
6+
bool definitional_equal(Expression *a, Expression *b);
7+
8+
#endif // DEFINITIONAL_EQUAL_H

src/kernel/expression.c

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "src/common/linear_map.h"
99
#include "src/kernel/beta_reduction.h"
1010
#include "src/kernel/context.h"
11+
#include "src/kernel/definitional_equal.h"
1112
#include "src/kernel/inductive.h"
1213
#include "src/kernel/structural.h"
1314
#include "src/kernel/subst.h"
@@ -1328,7 +1329,7 @@ bool is_hole(Expression *expr) { return expr->tag == HOLE_EXPRESSION; }
13281329
// 3) Term does not itself contain the hole.
13291330
// This does no modifications/creates no new objects.
13301331
bool can_fill(Expression *hole, Expression *term) {
1331-
bool types_match = congruent_with_holes(get_expression_type(hole), get_expression_type(term));
1332+
bool types_match = definitional_equal(get_expression_type(hole), get_expression_type(term));
13321333
if (get_maybe_hole_free(term)) {
13331334
return types_match && valid_in_context(term, get_expression_context(hole));
13341335
}
@@ -1412,71 +1413,57 @@ void fill_hole(Expression *hole, Expression *term) {
14121413
return;
14131414
}
14141415

1415-
// check if term satisfies hole type...
1416-
LinearMap *alpha_equivalences = linear_map_new();
1417-
LinearMap *required_holes = linear_map_new();
1418-
bool types_match = _congruent_with_holes(get_expression_type(hole), get_expression_type(term),
1419-
alpha_equivalences, required_holes);
1416+
bool types_match = definitional_equal(get_expression_type(hole), get_expression_type(term));
14201417
if (!types_match) {
1421-
linear_map_clear_free(alpha_equivalences);
1422-
linear_map_clear_free(required_holes);
14231418
return; // Todo: signal that this has failed?
14241419
}
1425-
1426-
int n = required_holes->size;
1427-
for (int i = 0; i < n; i++) {
1428-
Expression *hole = (required_holes->items + i)->key;
1429-
Expression *substitute = (required_holes->items + i)->val;
1430-
fill_hole(hole, substitute);
1431-
}
1432-
14331420
DoublyLinkedList *holepars = hole->uplinks;
14341421
for (int i = 0; i < dll_len(holepars); i++) {
14351422
Uplink *uplink = dll_at(holepars, i)->data;
14361423
switch (uplink->relation) {
14371424
case (LAMBDA_BODY): {
14381425
Expression *ptr = (Expression *)uplink->ptr;
1439-
ptr->as.lambda.body = term;
1426+
SET_LAMBDA_BODY(ptr, term);
14401427
break;
14411428
}
14421429
case (LAMBDA_BOUND_VAR): {
14431430
Expression *ptr = (Expression *)uplink->ptr;
1444-
ptr->as.lambda.bound_variable = term;
1431+
SET_LAMBDA_BOUND_VAR(ptr, term);
14451432
break;
14461433
}
14471434
case (APP_FUNC): {
14481435
Expression *ptr = (Expression *)uplink->ptr;
1449-
ptr->as.app.func = term;
1436+
SET_APP_FUNC(ptr, term);
14501437
break;
14511438
}
14521439
case (APP_ARG): {
14531440
Expression *ptr = (Expression *)uplink->ptr;
1454-
ptr->as.app.arg = term;
1441+
SET_APP_ARG(ptr, term);
14551442
break;
14561443
}
14571444
case (FORALL_BODY): {
14581445
Expression *ptr = (Expression *)uplink->ptr;
1459-
ptr->as.forall.body = term;
1446+
SET_FORALL_BODY(ptr, term);
14601447
break;
14611448
}
14621449
case (FORALL_BOUND_VAR): {
14631450
Expression *ptr = (Expression *)uplink->ptr;
1464-
ptr->as.forall.bound_variable = term;
1451+
SET_FORALL_BOUND_VAR(ptr, term);
14651452
break;
14661453
}
14671454
case (VAR_BODY): {
14681455
Expression *ptr = (Expression *)uplink->ptr;
1469-
ptr->as.var.body = term;
1456+
SET_VAR_BODY(ptr, term);
14701457
break;
14711458
}
14721459
case (EXPR_TYPE): {
14731460
Expression *ptr = (Expression *)uplink->ptr;
1474-
ptr->type = term;
1461+
SET_EXPR_TYPE(ptr, term);
14751462
break;
14761463
}
14771464
case (EXPR_CONTEXT): {
14781465
Expression *ptr = (Expression *)uplink->ptr;
1479-
ptr->context = term;
1466+
SET_EXPR_CONTEXT(ptr, term);
14801467
break;
14811468
}
14821469
default:
@@ -1485,9 +1472,6 @@ void fill_hole(Expression *hole, Expression *term) {
14851472
break;
14861473
}
14871474
}
1488-
1489-
linear_map_clear_free(alpha_equivalences);
1490-
linear_map_clear_free(required_holes);
14911475
}
14921476

14931477
char c_counter = 'a';

0 commit comments

Comments
 (0)