Skip to content

Commit a76529f

Browse files
committed
maintain tests
1 parent 07a7d12 commit a76529f

File tree

5 files changed

+116
-111
lines changed

5 files changed

+116
-111
lines changed

neurallogic/symbolic_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,9 @@ def eval_symbolic_expression(x: str):
308308
return eval(eval_str)
309309

310310

311-
#@dispatch
312-
#def eval_symbolic_expression(x: numpy.ndarray):
313-
# return numpy.vectorize(eval)(x)
311+
@dispatch
312+
def eval_symbolic_expression(x: numpy.ndarray):
313+
return numpy.vectorize(eval_symbolic_expression)(x)
314314

315315

316316
#@dispatch

neurallogic/symbolic_primitives.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import numpy
2-
from plum import dispatch
31
import typing
2+
43
import jax
54
import jax._src.lax_reference as lax_reference
5+
import numpy
6+
from plum import dispatch
67

78

89
# TODO: remove me?
@@ -175,6 +176,15 @@ def symbolic_operator(operator: str, x: bool, y: str):
175176
@dispatch
176177
def symbolic_operator(operator: str, x: str, y: numpy.ndarray):
177178
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
179+
@dispatch
180+
def symbolic_operator(operator: str, x: str, y: int):
181+
return symbolic_operator(operator, x, str(y))
182+
@dispatch
183+
def symbolic_operator(operator: str, x: list, y: numpy.ndarray):
184+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
185+
@dispatch
186+
def symbolic_operator(operator: str, x: numpy.ndarray, y: jax.numpy.ndarray):
187+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
178188
# XXX
179189

180190

@@ -188,6 +198,7 @@ def symbolic_operator(operator: str, x: list):
188198
return symbolic_operator(operator, numpy.array(x))
189199

190200

201+
# TODO: remove infix_operator?
191202
@dispatch
192203
def symbolic_infix_operator(operator: str, a: str, b: str) -> str:
193204
return f'{a} {operator} {b}'.replace('\'', '')

tests/test_hard_and.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from tests import utils
1010

1111

12-
1312
def test_include():
1413
test_data = [
1514
[[1.0, 1.0], 1.0],
@@ -23,7 +22,7 @@ def test_include():
2322
]
2423
for input, expected in test_data:
2524
utils.check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
26-
expected, input[0], input[1])
25+
expected, input[0], input[1])
2726

2827

2928
def test_neuron():
@@ -43,7 +42,7 @@ def hard(weights, input):
4342
return hard_and.hard_and_neuron(weights, input)
4443

4544
utils.check_consistency(soft, hard, expected,
46-
jax.numpy.array(weights), jax.numpy.array(input))
45+
jax.numpy.array(weights), jax.numpy.array(input))
4746

4847

4948
def test_layer():
@@ -66,9 +65,8 @@ def soft(weights, input):
6665
def hard(weights, input):
6766
return hard_and.hard_and_layer(weights, input)
6867

69-
7068
utils.check_consistency(soft, hard, jax.numpy.array(expected),
71-
jax.numpy.array(weights), jax.numpy.array(input))
69+
jax.numpy.array(weights), jax.numpy.array(input))
7270

7371

7472
def test_and():
@@ -172,7 +170,7 @@ def test_net(type, x):
172170
soft_input = jax.numpy.array([0.6, 0.45])
173171
weights = soft.init(random.PRNGKey(0), soft_input)
174172
soft_result = soft.apply(weights, numpy.array(soft_input))
175-
173+
176174
# Compute hard result
177175
hard_weights = harden.hard_weights(weights)
178176
hard_input = harden.harden(soft_input)
@@ -189,25 +187,24 @@ def test_net(type, x):
189187
symbolic_input = ['True', 'False']
190188
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
191189
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
192-
symbolic_output = symbolic_generation.eval_symbolic_expression(symbolic_output)
190+
symbolic_output = symbolic_generation.eval_symbolic_expression(
191+
symbolic_output)
193192
# Check that the symbolic result is the same as the hard result
194193
assert numpy.array_equal(symbolic_output, hard_result)
195194

196195
# Compute symbolic result with symbolic inputs and non-symbolic weights
197196
symbolic_input = ['x1', 'x2']
198197
symbolic_output = symbolic.apply(hard_weights, symbolic_input)
199198
# Check the form of the symbolic expression
200-
assert numpy.array_equal(symbolic_output, ['True and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True) and ((True and ((x1 != 0) or False) and ((x2 != 0) or False) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True)',
201-
'True and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or False) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True)',
202-
'True and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True) and ((True and ((x1 != 0) or False) and ((x2 != 0) or False) != 0) or True) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True)',
203-
'True and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or True) and ((True and ((x1 != 0) or False) and ((x2 != 0) or False) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False) and ((True and ((x1 != 0) or False) and ((x2 != 0) or True) != 0) or False)'])
199+
assert numpy.array_equal(symbolic_output, ['numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True))',
200+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True))',
201+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), False)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True))',
202+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), True)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), False)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False)), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), False)), numpy.logical_or(lax_reference.ne(x2, 0), True)), 0), False))'])
204203

205204
# Compute symbolic result with symbolic inputs and symbolic weights
206205
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
207206
# Check the form of the symbolic expression
208-
assert numpy.array_equal(symbolic_output, ['True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))',
209-
'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))',
210-
'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0)))',
211-
'True and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((False != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((True != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0))) and ((True and ((x1 != 0) or not((True != 0))) and ((x2 != 0) or not((False != 0))) != 0) or not((True != 0)))'])
212-
213-
207+
assert numpy.array_equal(symbolic_output, ['numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))',
208+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))',
209+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0))))',
210+
'numpy.logical_and(numpy.logical_and(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(False, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(True, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(numpy.logical_and(numpy.logical_and(True, numpy.logical_or(lax_reference.ne(x1, 0), numpy.logical_not(lax_reference.ne(True, 0)))), numpy.logical_or(lax_reference.ne(x2, 0), numpy.logical_not(lax_reference.ne(False, 0)))), 0), numpy.logical_not(lax_reference.ne(True, 0))))'])

0 commit comments

Comments
 (0)