9
9
from tests import utils
10
10
11
11
12
-
13
12
def test_include ():
14
13
test_data = [
15
14
[[1.0 , 1.0 ], 1.0 ],
@@ -23,7 +22,7 @@ def test_include():
23
22
]
24
23
for input , expected in test_data :
25
24
utils .check_consistency (hard_and .soft_and_include , hard_and .hard_and_include ,
26
- expected , input [0 ], input [1 ])
25
+ expected , input [0 ], input [1 ])
27
26
28
27
29
28
def test_neuron ():
@@ -43,7 +42,7 @@ def hard(weights, input):
43
42
return hard_and .hard_and_neuron (weights , input )
44
43
45
44
utils .check_consistency (soft , hard , expected ,
46
- jax .numpy .array (weights ), jax .numpy .array (input ))
45
+ jax .numpy .array (weights ), jax .numpy .array (input ))
47
46
48
47
49
48
def test_layer ():
@@ -66,9 +65,8 @@ def soft(weights, input):
66
65
def hard (weights , input ):
67
66
return hard_and .hard_and_layer (weights , input )
68
67
69
-
70
68
utils .check_consistency (soft , hard , jax .numpy .array (expected ),
71
- jax .numpy .array (weights ), jax .numpy .array (input ))
69
+ jax .numpy .array (weights ), jax .numpy .array (input ))
72
70
73
71
74
72
def test_and ():
@@ -172,7 +170,7 @@ def test_net(type, x):
172
170
soft_input = jax .numpy .array ([0.6 , 0.45 ])
173
171
weights = soft .init (random .PRNGKey (0 ), soft_input )
174
172
soft_result = soft .apply (weights , numpy .array (soft_input ))
175
-
173
+
176
174
# Compute hard result
177
175
hard_weights = harden .hard_weights (weights )
178
176
hard_input = harden .harden (soft_input )
@@ -189,25 +187,24 @@ def test_net(type, x):
189
187
symbolic_input = ['True' , 'False' ]
190
188
symbolic_weights = symbolic_generation .make_symbolic (hard_weights )
191
189
symbolic_output = symbolic .apply (symbolic_weights , symbolic_input )
192
- symbolic_output = symbolic_generation .eval_symbolic_expression (symbolic_output )
190
+ symbolic_output = symbolic_generation .eval_symbolic_expression (
191
+ symbolic_output )
193
192
# Check that the symbolic result is the same as the hard result
194
193
assert numpy .array_equal (symbolic_output , hard_result )
195
194
196
195
# Compute symbolic result with symbolic inputs and non-symbolic weights
197
196
symbolic_input = ['x1' , 'x2' ]
198
197
symbolic_output = symbolic .apply (hard_weights , symbolic_input )
199
198
# Check the form of the symbolic expression
200
- assert numpy .array_equal (symbolic_output , ['True and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or False) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True)' ,
201
- ' True and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or False) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True)' ,
202
- ' True and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or False) != 0) or True) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True)' ,
203
- ' True and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or True) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or False) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False) and (( True and ( (x1 != 0) or False) and ( (x2 != 0) or True) != 0) or False)' ])
199
+ assert numpy .array_equal (symbolic_output , ['numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True) )' ,
200
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True) )' ,
201
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), False)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True) )' ,
202
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and( True, numpy.logical_or(lax_reference.ne (x1, 0), False)), numpy.logical_or(lax_reference.ne (x2, 0), True)), 0), False) )' ])
204
203
205
204
# Compute symbolic result with symbolic inputs and symbolic weights
206
205
symbolic_output = symbolic .apply (symbolic_weights , symbolic_input )
207
206
# Check the form of the symbolic expression
208
- assert numpy .array_equal (symbolic_output , ['True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))' ,
209
- 'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))' ,
210
- 'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))' ,
211
- 'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0)))' ])
212
-
213
-
207
+ assert numpy .array_equal (symbolic_output , ['numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))' ,
208
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))' ,
209
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))' ,
210
+ 'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0))))' ])
0 commit comments