Skip to content

Commit 8ae5251

Browse files
committed
more refactoring and renaming
1 parent 90dab25 commit 8ae5251

File tree

6 files changed

+60
-60
lines changed

6 files changed

+60
-60
lines changed

neurallogic/hard_and.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Callable
77

88

9-
from neurallogic import neural_logic_net, sym_gen
9+
from neurallogic import neural_logic_net, symbolic_generation
1010

1111

1212
def soft_and_include(w: float, x: float) -> float:
@@ -100,8 +100,8 @@ def __init__(self, layer_size):
100100
self.hard_and_layer = HardAndLayer(self.layer_size)
101101

102102
def __call__(self, x):
103-
jaxpr = sym_gen.make_symbolic_flax_jaxpr(self.hard_and_layer, x)
104-
return sym_gen.symbolic_expression(jaxpr, x)
103+
jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_and_layer, x)
104+
return symbolic_generation.symbolic_expression(jaxpr, x)
105105

106106

107107
and_layer = neural_logic_net.select(
File renamed without changes.

tests/test_hard_and.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy
77
import typing
88

9-
from neurallogic import hard_and, harden, neural_logic_net, sym_gen
9+
from neurallogic import hard_and, harden, neural_logic_net, symbolic_generation
1010

1111

1212
def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
@@ -19,8 +19,8 @@ def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *a
1919
assert numpy.allclose(hard(*hard_args), hard_expected)
2020

2121
# Check that the jaxpr performs as expected
22-
symbolic_f = sym_gen.make_symbolic_jaxpr(hard, *hard_args)
23-
assert numpy.allclose(sym_gen.eval_symbolic(
22+
symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args)
23+
assert numpy.allclose(symbolic_generation.eval_symbolic(
2424
symbolic_f, *hard_args), hard_expected)
2525

2626

@@ -209,7 +209,7 @@ def test_net(type, x):
209209
'True and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or True) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or False) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False) and (True and (x1 != 0.0 or False) and (x2 != 0.0 or True) != 0.0 or False)'])
210210

211211
# Compute symbolic result with symbolic inputs and symbolic weights
212-
symbolic_weights = sym_gen.make_symbolic(hard_weights)
212+
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
213213
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
214214
# Check the form of the symbolic expression
215215
assert numpy.array_equal(symbolic_output, ['True and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(True != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(True != 0.0)) and (True and (x1 != 0.0 or not(True != 0.0)) and (x2 != 0.0 or not(False != 0.0)) != 0.0 or not(False != 0.0))',
@@ -220,7 +220,7 @@ def test_net(type, x):
220220
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
221221
symbolic_input = ['True', 'False']
222222
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
223-
symbolic_output = sym_gen.eval_symbolic_expression(symbolic_output)
223+
symbolic_output = symbolic_generation.eval_symbolic_expression(symbolic_output)
224224
# Check that the symbolic result is the same as the hard result
225225
assert numpy.array_equal(symbolic_output, hard_result)
226226

tests/test_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from flax.training import train_state
1111
import ml_collections
1212
from neurallogic import (hard_not, hard_or, harden, harden_layer,
13-
neural_logic_net, primitives)
13+
neural_logic_net)
1414
import optax
1515

1616

tests/test_sym_gen.py renamed to tests/test_symbolic_generation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_not, sym_gen, primitives, symbolic_primitives
1+
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_not, symbolic_generation, symbolic_primitives
22
from tests import test_mnist
33
import numpy
44
import jax
@@ -8,14 +8,14 @@
88
def nln(type, x, width):
99
x = hard_or.or_layer(type)(width)(x)
1010
x = hard_not.not_layer(type)(10)(x)
11-
x = primitives.nl_ravel(type)(x)
11+
x = x.ravel()
1212
x = harden_layer.harden_layer(type)(x)
13-
x = primitives.nl_reshape(type)((10, width))(x)
14-
x = primitives.nl_sum(type)(-1)(x)
13+
x = x.reshape((10, width))
14+
x = x.sum(-1)
1515
return x
1616

1717

18-
def test_sym_gen():
18+
def test_symbolic_generation():
1919
# Get MNIST dataset
2020
train_ds, test_ds = test_mnist.get_datasets()
2121
# Flatten images
@@ -40,28 +40,28 @@ def test_sym_gen():
4040
hard_output = hard.apply(hard_weights, hard_mock_input)
4141

4242
# Create a jaxpr from the neural logic net (with an arbitrary image input to set sizes)
43-
symbolic_net = sym_gen.make_symbolic(lambda x: hard.apply(hard_weights, x), harden.harden(test_ds['image'][0]))
43+
symbolic_net = symbolic_generation.make_symbolic(lambda x: hard.apply(hard_weights, x), harden.harden(test_ds['image'][0]))
4444

4545
# -- TEST 1: Compare the standard evaluation of the network with the non-standard evaluation of the jaxpr
4646
# Evaluate the jaxpr with the hard input
47-
eval_hard_output = sym_gen.eval_symbolic(symbolic_net, hard_mock_input)
47+
eval_hard_output = symbolic_generation.eval_symbolic(symbolic_net, hard_mock_input)
4848
# If this assertion succeeds then the non-standard evaluation of the jaxpr is is identical to the standard evaluation of network
4949
assert numpy.array_equal(eval_hard_output, hard_output)
5050

5151
# -- TEST 2: Compare the standard evaluation of the network with the non-standard symbolic evaluation of the jaxpr
5252
# Convert the hard input to a symbolic input
5353
# TODO: move this conversion into compute_symbolic_output
54-
symbolic_mock_input = sym_gen.make_symbolic(
54+
symbolic_mock_input = symbolic_generation.make_symbolic(
5555
numpy.array(hard_mock_input, dtype=object))
5656
# Symbolically evaluate the jaxpr with the symbolic input
57-
symbolic_output = sym_gen.symbolic_expression(
57+
symbolic_output = symbolic_generation.symbolic_expression(
5858
symbolic_net, symbolic_mock_input)
5959
# If this assertion succeeds then the shape of the non-standard symbolic evaluation of the jaxpr
6060
# is identical to the shape of the standard evaluation of the jaxpr
6161
assert numpy.array_equal(hard_output.shape,
6262
symbolic_output.shape)
6363
# Compute the symbolic expression, i.e. perform the actual operations in the symbolic expression
64-
eval_symbolic_output = sym_gen.eval_symbolic_expression(
64+
eval_symbolic_output = symbolic_generation.eval_symbolic_expression(
6565
symbolic_output)
6666
# If this assertion succeeds then the non-standard symbolic evaluation of the jaxpr is is identical to the standard evaluation of network
6767
assert numpy.array_equal(hard_output, eval_symbolic_output)

tests/test_symbolic_primitives.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import numpy
22
import jax
3-
from neurallogic import symbolic_primitives, sym_gen
3+
from neurallogic import symbolic_primitives, symbolic_generation
44

55

66
def test_unary_operator_str():
77
output = symbolic_primitives.unary_operator("not", "True")
88
expected = "not(True)"
99
assert output == expected
10-
eval_output = sym_gen.eval_symbolic_expression(output)
11-
eval_expected = sym_gen.eval_symbolic_expression(expected)
10+
eval_output = symbolic_generation.eval_symbolic_expression(output)
11+
eval_expected = symbolic_generation.eval_symbolic_expression(expected)
1212
assert eval_output == eval_expected
1313

1414

@@ -17,8 +17,8 @@ def test_unary_operator_vector():
1717
output = symbolic_primitives.unary_operator("not", x)
1818
expected = numpy.array(["not(True)", "not(False)"])
1919
assert numpy.array_equal(output, expected)
20-
eval_output = sym_gen.eval_symbolic_expression(output)
21-
eval_expected = sym_gen.eval_symbolic_expression(expected)
20+
eval_output = symbolic_generation.eval_symbolic_expression(output)
21+
eval_expected = symbolic_generation.eval_symbolic_expression(expected)
2222
assert numpy.array_equal(eval_output, eval_expected)
2323

2424

@@ -28,18 +28,18 @@ def test_unary_operator_matrix():
2828
expected = numpy.array(
2929
[["not(True)", "not(False)"], ["not(False)", "not(True)"]])
3030
assert numpy.array_equal(output, expected)
31-
eval_output = sym_gen.eval_symbolic_expression(output)
32-
eval_expected = sym_gen.eval_symbolic_expression(expected)
31+
eval_output = symbolic_generation.eval_symbolic_expression(output)
32+
eval_expected = symbolic_generation.eval_symbolic_expression(expected)
3333
assert numpy.array_equal(eval_output, eval_expected)
3434

3535

3636
def test_binary_operator_str_str():
3737
output = symbolic_primitives.binary_infix_operator("+", "1", "2")
3838
expected = "1 + 2"
3939
assert output == expected
40-
eval_output = sym_gen.eval_symbolic_expression(output)
41-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
42-
"1"), sym_gen.eval_symbolic_expression("2"))
40+
eval_output = symbolic_generation.eval_symbolic_expression(output)
41+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
42+
"1"), symbolic_generation.eval_symbolic_expression("2"))
4343
assert numpy.array_equal(eval_output, numpy_output)
4444

4545

@@ -49,9 +49,9 @@ def test_binary_operator_vector_vector():
4949
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
5050
expected = numpy.array(["1 + 3", "2 + 4"])
5151
assert numpy.array_equal(output, expected)
52-
eval_output = sym_gen.eval_symbolic_expression(output)
53-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
54-
x1), sym_gen.eval_symbolic_expression(x2))
52+
eval_output = symbolic_generation.eval_symbolic_expression(output)
53+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
54+
x1), symbolic_generation.eval_symbolic_expression(x2))
5555
assert numpy.array_equal(eval_output, numpy_output)
5656

5757

@@ -61,9 +61,9 @@ def test_binary_operator_matrix_vector():
6161
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
6262
expected = numpy.array([["1 + 5", "2 + 6"], ["3 + 5", "4 + 6"]])
6363
assert numpy.array_equal(output, expected)
64-
eval_output = sym_gen.eval_symbolic_expression(output)
65-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
66-
x1), sym_gen.eval_symbolic_expression(x2))
64+
eval_output = symbolic_generation.eval_symbolic_expression(output)
65+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
66+
x1), symbolic_generation.eval_symbolic_expression(x2))
6767
assert numpy.array_equal(eval_output, numpy_output)
6868

6969

@@ -73,9 +73,9 @@ def test_binary_operator_vector_matrix():
7373
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
7474
expected = numpy.array([["1 + 3", "2 + 4"], ["1 + 5", "2 + 6"]])
7575
assert numpy.array_equal(output, expected)
76-
eval_output = sym_gen.eval_symbolic_expression(output)
77-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
78-
x1), sym_gen.eval_symbolic_expression(x2))
76+
eval_output = symbolic_generation.eval_symbolic_expression(output)
77+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
78+
x1), symbolic_generation.eval_symbolic_expression(x2))
7979
assert numpy.array_equal(eval_output, numpy_output)
8080

8181

@@ -85,9 +85,9 @@ def test_binary_operator_matrix_matrix():
8585
output = symbolic_primitives.binary_infix_operator("+", x1, x2)
8686
expected = numpy.array([["1 + 5", "2 + 6"], ["3 + 7", "4 + 8"]])
8787
assert numpy.array_equal(output, expected)
88-
eval_output = sym_gen.eval_symbolic_expression(output)
89-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
90-
x1), sym_gen.eval_symbolic_expression(x2))
88+
eval_output = symbolic_generation.eval_symbolic_expression(output)
89+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
90+
x1), symbolic_generation.eval_symbolic_expression(x2))
9191
assert numpy.array_equal(eval_output, numpy_output)
9292

9393

@@ -100,9 +100,9 @@ def test_binary_operator_matrix_matrix_2():
100100
expected = numpy.array(
101101
[["1 + 5", "2 + 6", "3 + 7", "4 + 8"] for _ in range(10)])
102102
assert numpy.array_equal(output, expected)
103-
eval_output = sym_gen.eval_symbolic_expression(output)
104-
numpy_output = numpy.add(sym_gen.eval_symbolic_expression(
105-
x1), sym_gen.eval_symbolic_expression(x2))
103+
eval_output = symbolic_generation.eval_symbolic_expression(output)
104+
numpy_output = numpy.add(symbolic_generation.eval_symbolic_expression(
105+
x1), symbolic_generation.eval_symbolic_expression(x2))
106106
assert numpy.array_equal(eval_output, numpy_output)
107107

108108

@@ -141,34 +141,34 @@ def test_to_boolean_value_string():
141141

142142

143143
def test_symbolic_eval():
144-
output = sym_gen.eval_symbolic_expression("1 + 2")
144+
output = symbolic_generation.eval_symbolic_expression("1 + 2")
145145
expected = 3
146146
assert output == expected
147-
output = sym_gen.eval_symbolic_expression("[1, 2, 3]")
147+
output = symbolic_generation.eval_symbolic_expression("[1, 2, 3]")
148148
expected = [1, 2, 3]
149149
assert numpy.array_equal(output, expected)
150-
output = sym_gen.eval_symbolic_expression("[1, 2, 3] + [4, 5, 6]")
150+
output = symbolic_generation.eval_symbolic_expression("[1, 2, 3] + [4, 5, 6]")
151151
expected = [1, 2, 3, 4, 5, 6]
152152
assert numpy.array_equal(output, expected)
153-
output = sym_gen.eval_symbolic_expression(['1', '2', '3'])
153+
output = symbolic_generation.eval_symbolic_expression(['1', '2', '3'])
154154
expected = [1, 2, 3]
155155
assert numpy.array_equal(output, expected)
156-
output = sym_gen.eval_symbolic_expression(
156+
output = symbolic_generation.eval_symbolic_expression(
157157
['1', '2', '3'] + ['4', '5', '6'])
158158
expected = [1, 2, 3, 4, 5, 6]
159159
assert numpy.array_equal(output, expected)
160-
output = sym_gen.eval_symbolic_expression(['not(False)', 'not(True)'])
160+
output = symbolic_generation.eval_symbolic_expression(['not(False)', 'not(True)'])
161161
expected = [True, False]
162162
assert numpy.array_equal(output, expected)
163-
output = sym_gen.eval_symbolic_expression(
163+
output = symbolic_generation.eval_symbolic_expression(
164164
[['not(False)', 'not(True)'] + ['not(False)', 'not(True)']])
165165
expected = [[True, False, True, False]]
166166
assert numpy.array_equal(output, expected)
167-
output = sym_gen.eval_symbolic_expression(numpy.array(
167+
output = symbolic_generation.eval_symbolic_expression(numpy.array(
168168
[['not(False)', 'not(True)'] + ['not(False)', 'not(True)']]))
169169
expected = [[True, False, True, False]]
170170
assert numpy.array_equal(output, expected)
171-
output = sym_gen.eval_symbolic_expression(numpy.array(
171+
output = symbolic_generation.eval_symbolic_expression(numpy.array(
172172
[['not(False)', False], ['not(False)', 'not(True)']]))
173173
expected = [[True, False], [True, False]]
174174
assert numpy.array_equal(output, expected)
@@ -179,7 +179,7 @@ def test_symbolic_not():
179179
output = symbolic_primitives.symbolic_not(x1)
180180
expected = numpy.array([False, True])
181181
assert numpy.array_equal(output, expected)
182-
x1 = sym_gen.make_symbolic(x1)
182+
x1 = symbolic_generation.make_symbolic(x1)
183183
output = symbolic_primitives.symbolic_not(x1)
184184
expected = numpy.array(["not(True)", "not(False)"])
185185
assert numpy.array_equal(output, expected)
@@ -191,8 +191,8 @@ def test_symbolic_and():
191191
output = symbolic_primitives.symbolic_and(x1, x2)
192192
expected = numpy.array([True, False])
193193
assert numpy.array_equal(output, expected)
194-
x1 = sym_gen.make_symbolic(x1)
195-
x2 = sym_gen.make_symbolic(x2)
194+
x1 = symbolic_generation.make_symbolic(x1)
195+
x2 = symbolic_generation.make_symbolic(x2)
196196
output = symbolic_primitives.symbolic_and(x1, x2)
197197
expected = numpy.array(["True and True", "False and True"])
198198
assert numpy.array_equal(output, expected)
@@ -204,8 +204,8 @@ def test_symbolic_xor():
204204
output = symbolic_primitives.symbolic_xor(x1, x2)
205205
expected = numpy.array([False, True])
206206
assert numpy.array_equal(output, expected)
207-
x1 = sym_gen.make_symbolic(x1)
208-
x2 = sym_gen.make_symbolic(x2)
207+
x1 = symbolic_generation.make_symbolic(x1)
208+
x2 = symbolic_generation.make_symbolic(x2)
209209
output = symbolic_primitives.symbolic_xor(x1, x2)
210210
expected = numpy.array(["True ^ True", "False ^ True"])
211211
assert numpy.array_equal(output, expected)
@@ -242,12 +242,12 @@ def symbolic_reduce_or_impl(input, expected, symbolic_expected, axes):
242242
expected = numpy.array(expected)
243243
assert numpy.array_equal(output, expected)
244244
# Test symbolic implementation
245-
input = sym_gen.make_symbolic(input)
245+
input = symbolic_generation.make_symbolic(input)
246246
output = symbolic_primitives.symbolic_reduce_or(input, axes=axes)
247247
symbolic_expected = numpy.array(symbolic_expected)
248248
assert numpy.array_equal(output, symbolic_expected)
249249
# Compare the reference and symbolic evaluation
250-
symbolic_expected = sym_gen.eval_symbolic_expression(symbolic_expected)
250+
symbolic_expected = symbolic_generation.eval_symbolic_expression(symbolic_expected)
251251
assert numpy.array_equal(expected, symbolic_expected)
252252

253253

0 commit comments

Comments
 (0)