Commit 7d5f635

symbolic version of majority, plus additional tests
1 parent 5123507 commit 7d5f635

13 files changed, +133 -97 lines changed

neurallogic/hard_majority.py

Lines changed: 36 additions & 4 deletions
@@ -1,6 +1,7 @@
 import jax
+from flax import linen as nn

-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation


 def majority_index(input_size: int) -> int:
@@ -23,9 +24,40 @@ def hard_majority(x: jax.numpy.array) -> bool:
 hard_majority_layer = jax.vmap(hard_majority, in_axes=0)


-def symbolic_majority_layer(x):
-    return hard_majority_layer(x)
+class SoftMajorityLayer(nn.Module):
+    """
+    A soft-bit MAJORITY layer that transforms its inputs along the last dimension.
+
+    Attributes:
+        dtype: The dtype of the computation.
+    """
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        x = jax.numpy.asarray(x, self.dtype)  # TODO: remove me?
+        return soft_majority_layer(x)
+
+
+class HardMajorityLayer(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+        return hard_majority_layer(x)
+
+
+class SymbolicMajorityLayer:
+    def __init__(self):
+        self.hard_majority_layer = HardMajorityLayer()
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_majority_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)


 majority_layer = neural_logic_net.select(
-    soft_majority_layer, hard_majority_layer, symbolic_majority_layer)
+    lambda dtype=jax.numpy.float32: SoftMajorityLayer(dtype),
+    lambda dtype=jax.numpy.float32: HardMajorityLayer(),
+    lambda dtype=jax.numpy.float32: SymbolicMajorityLayer()
+)
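
For orientation, a minimal usage sketch of the three variants introduced above (not part of the commit; it mirrors the updated tests and assumes the neurallogic package layout shown in this diff):

import jax.numpy as jnp
from neurallogic import hard_majority, harden

x = jnp.array([[0.8, 0.1, 0.4], [1.0, 0.0, 0.3]])

# Soft variant: operates on soft bits in [0, 1].
y_soft = hard_majority.soft_majority_layer(x)

# Hard variant: operates on booleans produced by harden.harden.
y_hard = hard_majority.hard_majority_layer(harden.harden(x))

# Symbolic variant: traces HardMajorityLayer to a jaxpr, then evaluates it symbolically
# (per the SymbolicMajorityLayer definition above).
y_symbolic = hard_majority.SymbolicMajorityLayer()(harden.harden(x))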

neurallogic/hard_or.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
             'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)  # TODO: remove me?
         return soft_or_layer(weights, x)


neurallogic/hard_xor.py

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@ def soft_xor_neuron(w, x):

     def xor(x, y):
         return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
-    x = jax.lax.reduce(x, jax.numpy.array(0.0), xor, (0,))
+    x = jax.lax.reduce(x, jax.numpy.array(0, dtype=x.dtype), xor, (0,))
     return x


@@ -39,21 +39,21 @@ def hard_xor_neuron(w, x):

 soft_xor_layer = jax.vmap(soft_xor_neuron, (0, None), 0)

+
 hard_xor_layer = jax.vmap(hard_xor_neuron, (0, None), 0)


 class SoftXorLayer(nn.Module):
     layer_size: int
-    weights_init: Callable = nn.initializers.uniform(
-        1.0)  # TODO: investigate better initialization
+    weights_init: Callable = nn.initializers.uniform(1.0)
     dtype: jax.numpy.dtype = jax.numpy.float32

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
             'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)  # TODO is this needed?
         return soft_xor_layer(weights, x)

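A brief illustration of the reduce change (an assumption about the motivation, not from the commit): jax.lax.reduce expects the init value and the operand to share a dtype, so a hard-coded float32 zero breaks when the operand dtype differs. Building the init value from x.dtype keeps the soft XOR fold dtype-agnostic.

import jax
import jax.numpy as jnp

def soft_xor_reduce(x):
    # Soft XOR of all elements along axis 0, folded with jax.lax.reduce.
    def xor(a, b):
        return jnp.minimum(jnp.maximum(a, b), 1.0 - jnp.minimum(a, b))
    # The init value follows the operand dtype, as in the committed change.
    return jax.lax.reduce(x, jnp.array(0, dtype=x.dtype), xor, (0,))

print(soft_xor_reduce(jnp.array([0.9, 0.2, 0.7])))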

neurallogic/harden.py

Lines changed: 6 additions & 2 deletions
@@ -62,12 +62,16 @@ def harden(x: flax.core.FrozenDict):
     return harden(x.unfreeze())


+"""
 @dispatch
 def harden(*args):
     if len(args) == 1:
-        return harden(args[0])
+        print(f'args = {args} of type {type(args)}')
+        arg = args[0]
+        print(f'args[0] = {arg}')
+        return tuple(harden(arg))
     return tuple([harden(arg) for arg in args])
-
+"""

 @dispatch
 def map_keys_nested(f, d: dict) -> dict:

neurallogic/map_at_elements.py

Lines changed: 5 additions & 1 deletion
@@ -30,6 +30,11 @@ def map_at_elements(x: numpy.float32, func: typing.Callable):
     return func(x)


+@dispatch
+def map_at_elements(x: numpy.int32, func: typing.Callable):
+    return func(x)
+
+
 @dispatch
 def map_at_elements(x: list, func: typing.Callable):
     return [map_at_elements(item, func) for item in x]
@@ -55,4 +60,3 @@ def map_at_elements(x: dict, func: typing.Callable):
 @dispatch
 def map_at_elements(x: tuple, func: typing.Callable):
     return tuple(map_at_elements(list(x), func))
-
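
A hypothetical illustration (not part of the commit) of what the new numpy.int32 overload enables: map_at_elements can now reach integer scalars nested inside containers, alongside the existing float32, list, tuple and dict overloads.

import numpy
from neurallogic import map_at_elements

data = [numpy.float32(0.7), (numpy.int32(1), numpy.int32(0))]
# Apply a thresholding function to every scalar leaf; expected roughly [True, (True, False)].
hardened = map_at_elements.map_at_elements(data, lambda v: v > 0.5)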

neurallogic/neural_logic_net.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from enum import Enum
 from flax import linen as nn
+from neurallogic import symbolic_generation

 NetType = Enum('NetType', ['Soft', 'Hard', 'Symbolic'])

neurallogic/symbolic_generation.py

Lines changed: 7 additions & 6 deletions
@@ -100,6 +100,7 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
     return jaxpr


+
 def eval_jaxpr(symbolic, jaxpr, consts, *args):
     '''Evaluates a jaxpr by interpreting it as Python code.

@@ -212,15 +213,15 @@ def eval_jaxpr_impl(jaxpr):
 def make_symbolic_jaxpr(func: typing.Callable, *args):
     return jax.make_jaxpr(lambda *args: func(*args))(*args)

-
-def eval_symbolic(symbolic_function, *args):
-    if hasattr(symbolic_function, 'literals'):
+# TODO: better name
+def eval_symbolic(jaxpr, *args):
+    if hasattr(jaxpr, 'literals'):
         return eval_jaxpr(
-            False, symbolic_function.jaxpr, symbolic_function.literals, *args
+            False, jaxpr.jaxpr, jaxpr.literals, *args
         )
-    return eval_jaxpr(False, symbolic_function.jaxpr, [], *args)
-
+    return eval_jaxpr(False, jaxpr.jaxpr, [], *args)

+# TODO: better name
 def symbolic_expression(jaxpr, *args):
     if hasattr(jaxpr, 'literals'):
         sym_expr = eval_jaxpr(True, jaxpr.jaxpr, jaxpr.literals, *args)
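
A hedged sketch of the renamed API (not itself part of the commit): both functions now take a jaxpr, such as one produced by make_symbolic_jaxpr, mirroring the test usage that this commit replaces below.

import jax.numpy as jnp
from neurallogic import hard_majority, symbolic_generation

x = jnp.array([[True, False, True]])

# Trace the hard majority layer into a closed jaxpr...
jaxpr = symbolic_generation.make_symbolic_jaxpr(hard_majority.hard_majority_layer, x)

# ...then either evaluate it numerically or produce its symbolic form.
numeric = symbolic_generation.eval_symbolic(jaxpr, x)
symbolic = symbolic_generation.symbolic_expression(jaxpr, x)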

neurallogic/symbolic_operator.py

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,11 @@ def symbolic_operator(operator: str, x: str, y: int):
     return symbolic_operator(operator, x, str(y))


+@dispatch
+def symbolic_operator(operator: str, x: tuple):
+    return symbolic_operator(operator, str(x))
+
+
 @dispatch
 def symbolic_operator(operator: str, x: list, y: numpy.ndarray):
     return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)

neurallogic/symbolic_primitives.py

Lines changed: 24 additions & 21 deletions
@@ -21,83 +21,86 @@ def all_concrete_values(data):
     return True


-def symbolic_f(concrete_f: Callable, symbolic_f_name: str, *args, **kwargs):
+def symbolic(concrete_function: Callable, symbolic_function: str, *args, **kwargs):
     if all_concrete_values([*args]):
-        return concrete_f(*args, **kwargs)
+        # We can directly evaluate the function
+        return concrete_function(*args, **kwargs)
     else:
-        return symbolic_operator.symbolic_operator(symbolic_f_name, *args, **kwargs)
+        # We need to return a symbolic representation
+        return symbolic_operator.symbolic_operator(symbolic_function, *args, **kwargs)


 def symbolic_not(*args, **kwargs):
-    return symbolic_f(numpy.logical_not, 'numpy.logical_not', *args, **kwargs)
+    return symbolic(numpy.logical_not, 'numpy.logical_not', *args, **kwargs)


 def symbolic_eq(*args, **kwargs):
-    return symbolic_f(lax_reference.eq, 'lax_reference.eq', *args, **kwargs)
+    return symbolic(lax_reference.eq, 'lax_reference.eq', *args, **kwargs)


 def symbolic_ne(*args, **kwargs):
-    return symbolic_f(lax_reference.ne, 'lax_reference.ne', *args, **kwargs)
+    return symbolic(lax_reference.ne, 'lax_reference.ne', *args, **kwargs)


 def symbolic_le(*args, **kwargs):
-    return symbolic_f(lax_reference.le, 'lax_reference.le', *args, **kwargs)
+    return symbolic(lax_reference.le, 'lax_reference.le', *args, **kwargs)


 def symbolic_lt(*args, **kwargs):
-    return symbolic_f(lax_reference.lt, 'lax_reference.lt', *args, **kwargs)
+    return symbolic(lax_reference.lt, 'lax_reference.lt', *args, **kwargs)


 def symbolic_ge(*args, **kwargs):
-    return symbolic_f(lax_reference.ge, 'lax_reference.ge', *args, **kwargs)
+    return symbolic(lax_reference.ge, 'lax_reference.ge', *args, **kwargs)


 def symbolic_gt(*args, **kwargs):
-    return symbolic_f(lax_reference.gt, 'lax_reference.gt', *args, **kwargs)
+    return symbolic(lax_reference.gt, 'lax_reference.gt', *args, **kwargs)


 def symbolic_abs(*args, **kwargs):
-    return symbolic_f(lax_reference.abs, 'numpy.absolute', *args, **kwargs)
+    return symbolic(lax_reference.abs, 'numpy.absolute', *args, **kwargs)


 def symbolic_add(*args, **kwargs):
-    return symbolic_f(lax_reference.add, 'numpy.add', *args, **kwargs)
+    return symbolic(lax_reference.add, 'numpy.add', *args, **kwargs)


 def symbolic_sub(*args, **kwargs):
-    return symbolic_f(lax_reference.sub, 'numpy.subtract', *args, **kwargs)
+    return symbolic(lax_reference.sub, 'numpy.subtract', *args, **kwargs)


 def symbolic_mul(*args, **kwargs):
-    return symbolic_f(lax_reference.mul, 'numpy.multiply', *args, **kwargs)
+    return symbolic(lax_reference.mul, 'numpy.multiply', *args, **kwargs)


 def symbolic_div(*args, **kwargs):
-    return symbolic_f(lax_reference.div, 'lax_reference.div', *args, **kwargs)
+    return symbolic(lax_reference.div, 'lax_reference.div', *args, **kwargs)


 def symbolic_max(*args, **kwargs):
-    return symbolic_f(lax_reference.max, 'numpy.maximum', *args, **kwargs)
+    return symbolic(lax_reference.max, 'numpy.maximum', *args, **kwargs)


 def symbolic_min(*args, **kwargs):
-    return symbolic_f(lax_reference.min, 'numpy.minimum', *args, **kwargs)
+    return symbolic(lax_reference.min, 'numpy.minimum', *args, **kwargs)


 def symbolic_and(*args, **kwargs):
-    return symbolic_f(numpy.logical_and, 'numpy.logical_and', *args, **kwargs)
+    return symbolic(numpy.logical_and, 'numpy.logical_and', *args, **kwargs)


 def symbolic_or(*args, **kwargs):
-    return symbolic_f(numpy.logical_or, 'numpy.logical_or', *args, **kwargs)
+    return symbolic(numpy.logical_or, 'numpy.logical_or', *args, **kwargs)


 def symbolic_xor(*args, **kwargs):
-    return symbolic_f(numpy.logical_xor, 'numpy.logical_xor', *args, **kwargs)
+    return symbolic(numpy.logical_xor, 'numpy.logical_xor', *args, **kwargs)


 def symbolic_sum(*args, **kwargs):
-    return symbolic_f(lax_reference.sum, 'lax_reference.sum', *args, **kwargs)
+    # N.B. We pass the tuple directly because we're summing over all args
+    return symbolic(lax_reference.sum, 'lax_reference.sum', args, **kwargs)


 def symbolic_broadcast_in_dim(*args, **kwargs):
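
To illustrate the renamed helper's behaviour (an assumption-based sketch, not from the commit): `symbolic` runs the concrete reference function when every argument is a concrete value, and otherwise defers to symbolic_operator to build a symbolic representation.

from neurallogic import symbolic_primitives

# All arguments concrete: the reference implementation runs directly.
print(symbolic_primitives.symbolic_add(1, 2))      # -> 3

# A string argument is treated as symbolic, so a symbolic expression is returned
# (its exact textual form is determined by symbolic_operator).
print(symbolic_primitives.symbolic_add('x1', 1))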

tests/test_hard_majority.py

Lines changed: 11 additions & 13 deletions
@@ -2,6 +2,7 @@
 import jax

 from neurallogic import hard_majority, harden, symbolic_generation
+from tests import utils


 def test_majority_index():
@@ -126,9 +127,7 @@ def test_hard_majority_layer():
         False, True, False, True, False, True], [False, True, False, True, False, True]])) == numpy.array([False, False, False]))


-def test_majority_layer():
-    soft, hard, symbolic = hard_majority.soft_majority_layer, hard_majority.hard_majority_layer, hard_majority.symbolic_majority_layer
-
+def test_layer():
     test_data = [
         [
             [[0.8, 0.1, 0.4], [1.0, 0.0, 0.3]],
@@ -151,16 +150,15 @@ def test_majority_layer():
             [0.3, 0.1, 0.2, 0.3, 0.0]
         ]
     ]
-
+
     for input, expected in test_data:
-        input = jax.numpy.array(input)
-        expected = jax.numpy.array(expected)
-        soft_output = soft(input)
-        assert jax.numpy.array_equal(soft_output, expected)
-        hard_output = hard(harden.harden(input))
-        assert jax.numpy.array_equal(hard_output, harden.harden(expected))
-        jaxpr = symbolic_generation.make_symbolic_jaxpr(symbolic, harden.harden(input))
-        symbolic_output = symbolic_generation.symbolic_expression(jaxpr, harden.harden(input))
-        assert jax.numpy.array_equal(symbolic_output, harden.harden(expected))
+        def soft(input):
+            return hard_majority.soft_majority_layer(input)
+
+        def hard(input):
+            return hard_majority.hard_majority_layer(input)
+
+        utils.check_consistency(soft, hard, jax.numpy.array(expected), jax.numpy.array(input))
+

 # TODO: test training the hard majority layer
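
tests/utils.py itself is not shown in this commit; based on the assertions it replaces, check_consistency presumably verifies something along these lines (an assumption, not the actual helper):

import jax
from neurallogic import harden

def check_consistency(soft, hard, expected, input):
    # Soft output should match the expected soft values...
    assert jax.numpy.array_equal(soft(input), expected)
    # ...and the hard output should match the hardened expectation.
    assert jax.numpy.array_equal(hard(harden.harden(input)), harden.harden(expected))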
