Skip to content

Commit 41d3de4

Browse files
authored
Simplify Implied Weaker Conjuncts (#188)
1 parent ec170de commit 41d3de4

File tree

3 files changed

+162
-33
lines changed

3 files changed

+162
-33
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package liquidjava.rj_language.opt;
22

3+
import liquidjava.processor.context.Context;
4+
import liquidjava.rj_language.Predicate;
35
import java.util.Map;
46

57
import liquidjava.processor.facade.AliasDTO;
@@ -11,6 +13,8 @@
1113
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
1214
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1315
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
16+
import liquidjava.smt.SMTEvaluator;
17+
import liquidjava.smt.SMTResult;
1418

1519
public class ExpressionSimplifier {
1620

@@ -90,6 +94,16 @@ private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode nod
9094
return leftSimplified;
9195
}
9296

97+
// remove weaker conjuncts (e.g. x > 0 && x > -1 => x > 0)
98+
if (implies(leftSimplified.getValue(), rightSimplified.getValue())) {
99+
return new ValDerivationNode(leftSimplified.getValue(),
100+
new BinaryDerivationNode(leftSimplified, rightSimplified, "&&"));
101+
}
102+
if (implies(rightSimplified.getValue(), leftSimplified.getValue())) {
103+
return new ValDerivationNode(rightSimplified.getValue(),
104+
new BinaryDerivationNode(leftSimplified, rightSimplified, "&&"));
105+
}
106+
93107
// return the conjunction with simplified children
94108
Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue());
95109
// only create origin if at least one child has a meaningful origin
@@ -191,4 +205,17 @@ private static ValDerivationNode unwrapBooleanLiterals(ValDerivationNode node) {
191205

192206
return node;
193207
}
208+
209+
/**
210+
* Checks whether one expression implies another by asking Z3, used to remove weaker conjuncts in the simplification
211+
*/
212+
private static boolean implies(Expression stronger, Expression weaker) {
213+
try {
214+
SMTResult result = new SMTEvaluator().verifySubtype(new Predicate(stronger), new Predicate(weaker),
215+
Context.getInstance());
216+
return result.isOk();
217+
} catch (Exception e) {
218+
return false;
219+
}
220+
}
194221
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package liquidjava.rj_language.opt;
22

33
import static org.junit.jupiter.api.Assertions.*;
4+
import static liquidjava.utils.TestUtils.*;
45

56
import java.util.List;
67
import java.util.Map;
@@ -15,7 +16,6 @@
1516
import liquidjava.rj_language.ast.UnaryExpression;
1617
import liquidjava.rj_language.ast.Var;
1718
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
18-
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
1919
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
2020
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
2121
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
@@ -1020,37 +1020,78 @@ void testTwoArgAliasWithNormalExpression() {
10201020
assertNull(rightNode.getOrigin());
10211021
}
10221022

1023-
/**
1024-
* Helper method to compare two derivation nodes recursively
1025-
*/
1026-
private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) {
1027-
if (expected == null && actual == null)
1028-
return;
1029-
1030-
assertNotNull(expected);
1031-
assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match");
1032-
if (expected instanceof ValDerivationNode expectedVal) {
1033-
ValDerivationNode actualVal = (ValDerivationNode) actual;
1034-
assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(),
1035-
message + ": values should match");
1036-
assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin");
1037-
} else if (expected instanceof BinaryDerivationNode expectedBin) {
1038-
BinaryDerivationNode actualBin = (BinaryDerivationNode) actual;
1039-
assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match");
1040-
assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left");
1041-
assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right");
1042-
} else if (expected instanceof VarDerivationNode expectedVar) {
1043-
VarDerivationNode actualVar = (VarDerivationNode) actual;
1044-
assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match");
1045-
} else if (expected instanceof UnaryDerivationNode expectedUnary) {
1046-
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
1047-
assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
1048-
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
1049-
} else if (expected instanceof IteDerivationNode expectedIte) {
1050-
IteDerivationNode actualIte = (IteDerivationNode) actual;
1051-
assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition");
1052-
assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then");
1053-
assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else");
1054-
}
1023+
@Test
1024+
void testEntailedConjunctIsRemovedButOriginIsPreserved() {
1025+
// Given: b >= 100 && b > 0
1026+
// Expected: b >= 100 (b >= 100 implies b > 0)
1027+
1028+
addIntVariableToContext("b");
1029+
Expression b = new Var("b");
1030+
Expression bGe100 = new BinaryExpression(b, ">=", new LiteralInt(100));
1031+
Expression bGt0 = new BinaryExpression(b, ">", new LiteralInt(0));
1032+
Expression fullExpression = new BinaryExpression(bGe100, "&&", bGt0);
1033+
1034+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
1035+
1036+
assertNotNull(result);
1037+
assertEquals("b >= 100", result.getValue().toString(),
1038+
"The weaker conjunct should be removed when implied by the stronger one");
1039+
1040+
ValDerivationNode expectedLeft = new ValDerivationNode(bGe100, null);
1041+
ValDerivationNode expectedRight = new ValDerivationNode(bGt0, null);
1042+
ValDerivationNode expected = new ValDerivationNode(bGe100,
1043+
new BinaryDerivationNode(expectedLeft, expectedRight, "&&"));
1044+
1045+
assertDerivationEquals(expected, result, "Entailment simplification should preserve conjunction origin");
1046+
}
1047+
1048+
@Test
1049+
void testStrictComparisonImpliesNonStrictComparison() {
1050+
// Given: x > y && x >= y
1051+
// Expected: x > y (x > y implies x >= y)
1052+
1053+
addIntVariableToContext("x");
1054+
addIntVariableToContext("y");
1055+
Expression x = new Var("x");
1056+
Expression y = new Var("y");
1057+
Expression xGtY = new BinaryExpression(x, ">", y);
1058+
Expression xGeY = new BinaryExpression(x, ">=", y);
1059+
Expression fullExpression = new BinaryExpression(xGtY, "&&", xGeY);
1060+
1061+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
1062+
1063+
assertNotNull(result);
1064+
assertEquals("x > y", result.getValue().toString(),
1065+
"The stricter comparison should be kept when it implies the weaker one");
1066+
1067+
ValDerivationNode expectedLeft = new ValDerivationNode(xGtY, null);
1068+
ValDerivationNode expectedRight = new ValDerivationNode(xGeY, null);
1069+
ValDerivationNode expected = new ValDerivationNode(xGtY,
1070+
new BinaryDerivationNode(expectedLeft, expectedRight, "&&"));
1071+
1072+
assertDerivationEquals(expected, result, "Strict comparison simplification should preserve conjunction origin");
1073+
}
1074+
1075+
@Test
1076+
void testEquivalentBoundsKeepOneSide() {
1077+
// Given: i >= 0 && 0 <= i
1078+
// Expected: 0 <= i (both conjuncts express the same condition)
1079+
addIntVariableToContext("i");
1080+
Expression i = new Var("i");
1081+
Expression zeroLeI = new BinaryExpression(new LiteralInt(0), "<=", i);
1082+
Expression iGeZero = new BinaryExpression(i, ">=", new LiteralInt(0));
1083+
Expression fullExpression = new BinaryExpression(zeroLeI, "&&", iGeZero);
1084+
1085+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
1086+
1087+
assertNotNull(result);
1088+
assertEquals("0 <= i", result.getValue().toString(), "Equivalent bounds should collapse to a single conjunct");
1089+
1090+
ValDerivationNode expectedLeft = new ValDerivationNode(zeroLeI, null);
1091+
ValDerivationNode expectedRight = new ValDerivationNode(iGeZero, null);
1092+
ValDerivationNode expected = new ValDerivationNode(zeroLeI,
1093+
new BinaryDerivationNode(expectedLeft, expectedRight, "&&"));
1094+
1095+
assertDerivationEquals(expected, result, "Equivalent bounds simplification should preserve conjunction origin");
10551096
}
10561097
}

liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
package liquidjava.utils;
22

3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNotNull;
5+
36
import java.io.IOException;
47
import java.nio.file.Files;
58
import java.nio.file.Path;
69
import java.util.Optional;
710
import java.util.stream.Stream;
811

12+
import liquidjava.processor.context.Context;
13+
import liquidjava.rj_language.Predicate;
14+
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
15+
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
16+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
17+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
18+
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
19+
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
20+
import spoon.Launcher;
21+
import spoon.reflect.factory.Factory;
22+
923
public class TestUtils {
1024

25+
private final static Factory factory = new Launcher().getFactory();
26+
private final static Context context = Context.getInstance();
27+
1128
/**
1229
* Determines if the given path indicates that the test should pass
1330
*
@@ -64,4 +81,48 @@ public static Optional<String> getExpectedErrorFromDirectory(Path dirPath) {
6481
}
6582
return Optional.empty();
6683
}
84+
85+
/**
86+
* Helper method to compare two derivation nodes recursively
87+
*/
88+
public static void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) {
89+
if (expected == null && actual == null)
90+
return;
91+
92+
assertNotNull(expected);
93+
assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match");
94+
if (expected instanceof ValDerivationNode expectedVal) {
95+
ValDerivationNode actualVal = (ValDerivationNode) actual;
96+
assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(),
97+
message + ": values should match");
98+
assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin");
99+
} else if (expected instanceof BinaryDerivationNode expectedBin) {
100+
BinaryDerivationNode actualBin = (BinaryDerivationNode) actual;
101+
assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match");
102+
assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left");
103+
assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right");
104+
} else if (expected instanceof VarDerivationNode expectedVar) {
105+
VarDerivationNode actualVar = (VarDerivationNode) actual;
106+
assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match");
107+
} else if (expected instanceof UnaryDerivationNode expectedUnary) {
108+
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
109+
assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
110+
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
111+
} else if (expected instanceof IteDerivationNode expectedIte) {
112+
IteDerivationNode actualIte = (IteDerivationNode) actual;
113+
assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition");
114+
assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then");
115+
assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else");
116+
}
117+
}
118+
119+
/**
120+
* Helper method to add an integer variable to the context Needed for tests that rely on the SMT-based implication
121+
* checks The simplifier asks Z3 whether one conjunct implies another, so every variable in those expressions must
122+
* be in the context
123+
*/
124+
public static void addIntVariableToContext(String name) {
125+
context.addVarToContext(name, factory.Type().INTEGER_PRIMITIVE, new Predicate(),
126+
factory.Code().createCodeSnippetStatement(""));
127+
}
67128
}

0 commit comments

Comments
 (0)