Skip to content

Commit 670f052

Browse files
committed
get test_symbolic_primitives working again
1 parent 14aac5b commit 670f052

File tree

1 file changed

+61
-101
lines changed

1 file changed

+61
-101
lines changed

tests/test_symbolic_primitives.py

Lines changed: 61 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import numpy
22
import jax
3-
from neurallogic import symbolic_primitives
3+
from neurallogic import symbolic_primitives, sym_gen
44

55

66
def test_unary_operator_str():
77
output = symbolic_primitives.unary_operator("not", "True")
88
expected = "not(True)"
99
assert output == expected
10-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
11-
eval_expected = symbolic_primitives.eval_symbolic_expression(expected)
10+
eval_output = sym_gen.eval_symbolic_expression(output)
11+
eval_expected = sym_gen.eval_symbolic_expression(expected)
1212
assert eval_output == eval_expected
1313

1414

@@ -17,8 +17,8 @@ def test_unary_operator_vector():
1717
output = symbolic_primitives.unary_operator("not", x)
1818
expected = numpy.array(["not(True)", "not(False)"])
1919
assert numpy.array_equal(output, expected)
20-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
21-
eval_expected = symbolic_primitives.eval_symbolic_expression(expected)
20+
eval_output = sym_gen.eval_symbolic_expression(output)
21+
eval_expected = sym_gen.eval_symbolic_expression(expected)
2222
assert numpy.array_equal(eval_output, eval_expected)
2323

2424

@@ -28,18 +28,18 @@ def test_unary_operator_matrix():
2828
expected = numpy.array(
2929
[["not(True)", "not(False)"], ["not(False)", "not(True)"]])
3030
assert numpy.array_equal(output, expected)
31-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
32-
eval_expected = symbolic_primitives.eval_symbolic_expression(expected)
31+
eval_output = sym_gen.eval_symbolic_expression(output)
32+
eval_expected = sym_gen.eval_symbolic_expression(expected)
3333
assert numpy.array_equal(eval_output, eval_expected)
3434

3535

3636
def test_binary_operator_str_str():
3737
output = symbolic_primitives.binary_infix_operator("+", "1", "2")
3838
expected = "1 + 2"
3939
assert output == expected
40-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
41-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
42-
"1"), symbolic_primitives.eval_symbolic_expression("2"))
40+
eval_output = sym_gen.eval_symbolic_expression(output)
41+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
42+
"1"), sym_gen.eval_symbolic_expression("2"))
4343
assert numpy.array_equal(eval_output, numpy_output)
4444

4545

@@ -49,9 +49,9 @@ def test_binary_operator_vector_vector():
4949
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
5050
expected = numpy.array(["1 + 3", "2 + 4"])
5151
assert numpy.array_equal(output, expected)
52-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
53-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
54-
x1), symbolic_primitives.eval_symbolic_expression(x2))
52+
eval_output = sym_gen.eval_symbolic_expression(output)
53+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
54+
x1), sym_gen.eval_symbolic_expression(x2))
5555
assert numpy.array_equal(eval_output, numpy_output)
5656

5757

@@ -61,9 +61,9 @@ def test_binary_operator_matrix_vector():
6161
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
6262
expected = numpy.array([["1 + 5", "2 + 6"], ["3 + 5", "4 + 6"]])
6363
assert numpy.array_equal(output, expected)
64-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
65-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
66-
x1), symbolic_primitives.eval_symbolic_expression(x2))
64+
eval_output = sym_gen.eval_symbolic_expression(output)
65+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
66+
x1), sym_gen.eval_symbolic_expression(x2))
6767
assert numpy.array_equal(eval_output, numpy_output)
6868

6969

@@ -73,9 +73,9 @@ def test_binary_operator_vector_matrix():
7373
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
7474
expected = numpy.array([["1 + 3", "2 + 4"], ["1 + 5", "2 + 6"]])
7575
assert numpy.array_equal(output, expected)
76-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
77-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
78-
x1), symbolic_primitives.eval_symbolic_expression(x2))
76+
eval_output = sym_gen.eval_symbolic_expression(output)
77+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
78+
x1), sym_gen.eval_symbolic_expression(x2))
7979
assert numpy.array_equal(eval_output, numpy_output)
8080

8181

@@ -85,9 +85,9 @@ def test_binary_operator_matrix_matrix():
8585
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
8686
expected = numpy.array([["1 + 5", "2 + 6"], ["3 + 7", "4 + 8"]])
8787
assert numpy.array_equal(output, expected)
88-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
89-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
90-
x1), symbolic_primitives.eval_symbolic_expression(x2))
88+
eval_output = sym_gen.eval_symbolic_expression(output)
89+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
90+
x1), sym_gen.eval_symbolic_expression(x2))
9191
assert numpy.array_equal(eval_output, numpy_output)
9292

9393

@@ -100,9 +100,9 @@ def test_binary_operator_matrix_matrix_2():
100100
expected = numpy.array(
101101
[["1 + 5", "2 + 6", "3 + 7", "4 + 8"] for _ in range(10)])
102102
assert numpy.array_equal(output, expected)
103-
eval_output = symbolic_primitives.eval_symbolic_expression(output)
104-
numpy_output = numpy.add(symbolic_primitives.eval_symbolic_expression(
105-
x1), symbolic_primitives.eval_symbolic_expression(x2))
103+
eval_output = sym_gen.eval_symbolic_expression(output)
104+
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
105+
x1), sym_gen.eval_symbolic_expression(x2))
106106
assert numpy.array_equal(eval_output, numpy_output)
107107

108108

@@ -139,76 +139,36 @@ def test_to_boolean_value_string():
139139
assert output == expected
140140

141141

142-
def test_to_boolean_symbolic_values():
143-
output = symbolic_primitives.make_symbolic([1, 1])
144-
expected = ["True", "True"]
145-
assert numpy.array_equal(output, expected)
146-
output = symbolic_primitives.make_symbolic([0, 0])
147-
expected = ["False", "False"]
148-
assert numpy.array_equal(output, expected)
149-
output = symbolic_primitives.make_symbolic([True, False])
150-
expected = ["True", "False"]
151-
assert numpy.array_equal(output, expected)
152-
output = symbolic_primitives.make_symbolic([False, True])
153-
expected = ["False", "True"]
154-
assert numpy.array_equal(output, expected)
155-
output = symbolic_primitives.make_symbolic([1.0, 1.0])
156-
expected = ["True", "True"]
157-
assert numpy.array_equal(output, expected)
158-
output = symbolic_primitives.make_symbolic([0.0, 0.0])
159-
expected = ["False", "False"]
160-
assert numpy.array_equal(output, expected)
161-
output = symbolic_primitives.make_symbolic([[1, 1], [1, 1]])
162-
expected = [["True", "True"], ["True", "True"]]
163-
assert numpy.array_equal(output, expected)
164-
output = symbolic_primitives.make_symbolic([[0, 0], [0, 0]])
165-
expected = [["False", "False"], ["False", "False"]]
166-
assert numpy.array_equal(output, expected)
167-
output = symbolic_primitives.make_symbolic(
168-
[[True, False], [False, True]])
169-
expected = [["True", "False"], ["False", "True"]]
170-
assert numpy.array_equal(output, expected)
171-
output = symbolic_primitives.make_symbolic(
172-
[[[1, 0, 1], [1, 0, 1]], [[1, 0, 0], [1, 0, 0]]])
173-
expected = [[["True", "False", "True"], ["True", "False", "True"]], [
174-
["True", "False", "False"], ["True", "False", "False"]]]
175-
assert numpy.array_equal(output, expected)
176-
output = symbolic_primitives.make_symbolic(
177-
[[[1, "f", 1], [1, "g", 1]], [[1, "h", 0], [1, "f", 0]]])
178-
expected = [[["True", "f", "True"], ["True", "g", "True"]], [
179-
["True", "h", "False"], ["True", "f", "False"]]]
180-
assert numpy.array_equal(output, expected)
181-
182142

183143
def test_symbolic_eval():
184-
output = symbolic_primitives.eval_symbolic_expression("1 + 2")
144+
output = sym_gen.eval_symbolic_expression("1 + 2")
185145
expected = 3
186146
assert output == expected
187-
output = symbolic_primitives.eval_symbolic_expression("[1, 2, 3]")
147+
output = sym_gen.eval_symbolic_expression("[1, 2, 3]")
188148
expected = [1, 2, 3]
189149
assert numpy.array_equal(output, expected)
190-
output = symbolic_primitives.eval_symbolic_expression("[1, 2, 3] + [4, 5, 6]")
150+
output = sym_gen.eval_symbolic_expression("[1, 2, 3] + [4, 5, 6]")
191151
expected = [1, 2, 3, 4, 5, 6]
192152
assert numpy.array_equal(output, expected)
193-
output = symbolic_primitives.eval_symbolic_expression(['1', '2', '3'])
153+
output = sym_gen.eval_symbolic_expression(['1', '2', '3'])
194154
expected = [1, 2, 3]
195155
assert numpy.array_equal(output, expected)
196-
output = symbolic_primitives.eval_symbolic_expression(
156+
output = sym_gen.eval_symbolic_expression(
197157
['1', '2', '3'] + ['4', '5', '6'])
198158
expected = [1, 2, 3, 4, 5, 6]
199159
assert numpy.array_equal(output, expected)
200-
output = symbolic_primitives.eval_symbolic_expression(['not(False)', 'not(True)'])
160+
output = sym_gen.eval_symbolic_expression(['not(False)', 'not(True)'])
201161
expected = [True, False]
202162
assert numpy.array_equal(output, expected)
203-
output = symbolic_primitives.eval_symbolic_expression(
163+
output = sym_gen.eval_symbolic_expression(
204164
[['not(False)', 'not(True)'] + ['not(False)', 'not(True)']])
205165
expected = [[True, False, True, False]]
206166
assert numpy.array_equal(output, expected)
207-
output = symbolic_primitives.eval_symbolic_expression(numpy.array(
167+
output = sym_gen.eval_symbolic_expression(numpy.array(
208168
[['not(False)', 'not(True)'] + ['not(False)', 'not(True)']]))
209169
expected = [[True, False, True, False]]
210170
assert numpy.array_equal(output, expected)
211-
output = symbolic_primitives.eval_symbolic_expression(numpy.array(
171+
output = sym_gen.eval_symbolic_expression(numpy.array(
212172
[['not(False)', False], ['not(False)', 'not(True)']]))
213173
expected = [[True, False], [True, False]]
214174
assert numpy.array_equal(output, expected)
@@ -219,7 +179,7 @@ def test_symbolic_not():
219179
output = symbolic_primitives.symbolic_not(x1)
220180
expected = numpy.array([False, True])
221181
assert numpy.array_equal(output, expected)
222-
x1 = symbolic_primitives.make_symbolic(x1)
182+
x1 = sym_gen.make_symbolic(x1)
223183
output = symbolic_primitives.symbolic_not(x1)
224184
expected = numpy.array(["not(True)", "not(False)"])
225185
assert numpy.array_equal(output, expected)
@@ -231,8 +191,8 @@ def test_symbolic_and():
231191
output = symbolic_primitives.symbolic_and(x1, x2)
232192
expected = numpy.array([True, False])
233193
assert numpy.array_equal(output, expected)
234-
x1 = symbolic_primitives.make_symbolic(x1)
235-
x2 = symbolic_primitives.make_symbolic(x2)
194+
x1 = sym_gen.make_symbolic(x1)
195+
x2 = sym_gen.make_symbolic(x2)
236196
output = symbolic_primitives.symbolic_and(x1, x2)
237197
expected = numpy.array(["True and True", "False and True"])
238198
assert numpy.array_equal(output, expected)
@@ -244,10 +204,10 @@ def test_symbolic_xor():
244204
output = symbolic_primitives.symbolic_xor(x1, x2)
245205
expected = numpy.array([False, True])
246206
assert numpy.array_equal(output, expected)
247-
x1 = symbolic_primitives.make_symbolic(x1)
248-
x2 = symbolic_primitives.make_symbolic(x2)
207+
x1 = sym_gen.make_symbolic(x1)
208+
x2 = sym_gen.make_symbolic(x2)
249209
output = symbolic_primitives.symbolic_xor(x1, x2)
250-
expected = numpy.array(["(True) ^ (True)", "(False) ^ (True)"])
210+
expected = numpy.array(["True ^ True", "False ^ True"])
251211
assert numpy.array_equal(output, expected)
252212

253213

@@ -282,40 +242,40 @@ def symbolic_reduce_or_impl(input, expected, symbolic_expected, axes):
282242
expected = numpy.array(expected)
283243
assert numpy.array_equal(output, expected)
284244
# Test symbolic implementation
285-
input = symbolic_primitives.make_symbolic(input)
245+
input = sym_gen.make_symbolic(input)
286246
output = symbolic_primitives.symbolic_reduce_or(input, axes=axes)
287247
symbolic_expected = numpy.array(symbolic_expected)
288248
assert numpy.array_equal(output, symbolic_expected)
289249
# Compare the reference and symbolic evaluation
290-
symbolic_expected = symbolic_primitives.eval_symbolic_expression(symbolic_expected)
250+
symbolic_expected = sym_gen.eval_symbolic_expression(symbolic_expected)
291251
assert numpy.array_equal(expected, symbolic_expected)
292252

293253

294254
def test_symbolic_reduce_or():
295255
# Test 1: 2D matrix with different axes inputs
296256
symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=[
297-
True, True], symbolic_expected=['False or True or False', 'False or True or False'], axes=(1,))
257+
True, True], symbolic_expected=['((False or True) or False)', '((False or True) or False)'], axes=(1,))
298258
symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=[
299-
True, False], symbolic_expected=['False or True or True', 'False or False or False'], axes=(0,))
259+
True, False], symbolic_expected=['((False or True) or True)', '((False or False) or False)'], axes=(0,))
300260
symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=True,
301-
symbolic_expected='False or True or False or True or False', axes=(0, 1))
261+
symbolic_expected='((((False or True) or False) or True) or False)', axes=(0, 1))
302262
# Test 2: 3D matrix with different axes inputs
303-
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=[
304-
'False or True or False or True or False', 'False or True or False or True or False'], axes=(1, 2))
305-
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=[
306-
'False or True or False or True or False', 'False or True or False or True or False'], axes=(0, 2))
307-
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, False], symbolic_expected=[
308-
'False or True or True or True or True', 'False or False or False or False or False'], axes=(0, 1))
263+
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=
264+
['((((False or True) or False) or True) or False)', '((((False or True) or False) or True) or False)'], axes=(1, 2))
265+
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=
266+
['((((False or True) or False) or True) or False)', '((((False or True) or False) or True) or False)'], axes=(0, 2))
267+
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, False], symbolic_expected=
268+
['((((False or True) or True) or True) or True)', '((((False or False) or False) or False) or False)'], axes=(0, 1))
309269
symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=True,
310-
symbolic_expected='False or True or False or True or False or True or False or True or False', axes=(0, 1, 2))
270+
symbolic_expected='((((((((False or True) or False) or True) or False) or True) or False) or True) or False)', axes=(0, 1, 2))
311271
# Test 3: 4D matrix with different axes inputs
312-
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=[
313-
'False or True or False or True or False or True or False or True or False', 'False or True or False or True or False or True or False or True or False'], axes=(1, 2, 3))
314-
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=[
315-
'False or True or False or True or False or True or False or True or False', 'False or True or False or True or False or True or False or True or False'], axes=(0, 2, 3))
316-
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=[
317-
'False or True or False or True or False or True or False or True or False', 'False or True or False or True or False or True or False or True or False'], axes=(0, 1, 3))
318-
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, False], symbolic_expected=[
319-
'False or True or True or True or True or True or True or True or True', 'False or False or False or False or False or False or False or False or False'], axes=(0, 1, 2))
272+
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=
273+
['((((((((False or True) or False) or True) or False) or True) or False) or True) or False)', '((((((((False or True) or False) or True) or False) or True) or False) or True) or False)'], axes=(1, 2, 3))
274+
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=
275+
['((((((((False or True) or False) or True) or False) or True) or False) or True) or False)', '((((((((False or True) or False) or True) or False) or True) or False) or True) or False)'], axes=(0, 2, 3))
276+
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=
277+
['((((((((False or True) or False) or True) or False) or True) or False) or True) or False)', '((((((((False or True) or False) or True) or False) or True) or False) or True) or False)'], axes=(0, 1, 3))
278+
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, False], symbolic_expected=
279+
['((((((((False or True) or True) or True) or True) or True) or True) or True) or True)', '((((((((False or False) or False) or False) or False) or False) or False) or False) or False)'], axes=(0, 1, 2))
320280
symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=True,
321-
symbolic_expected='False or True or False or True or False or True or False or True or False or True or False or True or False or True or False or True or False', axes=(0, 1, 2, 3))
281+
symbolic_expected='((((((((((((((((False or True) or False) or True) or False) or True) or False) or True) or False) or True) or False) or True) or False) or True) or False) or True) or False)', axes=(0, 1, 2, 3))

0 commit comments

Comments
 (0)