Commit dc48979

get hard_or test passing
1 parent 756fdb0 commit dc48979

3 files changed: +106 -115 lines changed

neurallogic/hard_or.py

Lines changed: 10 additions & 59 deletions
@@ -4,7 +4,7 @@
 import jax
 from flax import linen as nn
 
-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation
 
 
 def soft_or_include(w: float, x: float) -> float:
@@ -18,18 +18,10 @@ def soft_or_include(w: float, x: float) -> float:
     w = jax.numpy.clip(w, 0.0, 1.0)
     return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
 
-@jax.jit
-def hard_or_include(w: bool, x: bool) -> bool:
-    return x & w
 
-def symbolic_or_include(w, x):
-    expression = f"({x} and {w})"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+def hard_or_include(w, x):
+    return jax.numpy.logical_and(x, w)
+
 
 def soft_or_neuron(w, x):
     x = jax.vmap(soft_or_include, 0, 0)(w, x)
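Note on this hunk: an Or neuron gates each input x with its weight w through an "include" function and then ORs the gated values, so the include itself is a conjunction. soft_or_include computes min(x, w) (written as 1 - max(1 - x, 1 - w)), and the new hard_or_include computes the same thing on booleans with jax.numpy.logical_and, which JAX can trace; the hand-rolled string-building symbolic_or_include becomes unnecessary once symbolic evaluation works by tracing the hard function. A minimal standalone sketch (not from the repo) of the correspondence this preserves, assuming hardening means thresholding a soft bit at 0.5:

import jax

def soft_or_include(w, x):
    w = jax.numpy.clip(w, 0.0, 1.0)
    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)  # equivalent to min(x, w)

def hard_or_include(w, x):
    return jax.numpy.logical_and(x, w)

# On boolean-valued inputs, hardening the soft include gives exactly the hard include.
for w in (False, True):
    for x in (False, True):
        soft = soft_or_include(float(w), float(x))  # soft bit in [0.0, 1.0]
        hard = hard_or_include(w, x)                # hard (boolean) bit
        assert (float(soft) > 0.5) == bool(hard)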
@@ -39,31 +31,11 @@ def hard_or_neuron(w, x):
     x = jax.vmap(hard_or_include, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])
 
-def symbolic_or_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    y = [symbolic_or_include(wi, xi) for wi, xi in zip(w, x)]
-    expression = "(" + str(reduce(lambda a, b: f"{a} or {b}", y)) + ")"
-    if all(isinstance(yi, bool) for yi in y):
-        # We know the value of all yis, so we can evaluate the expression
-        return eval(expression)
-    return expression
 
 soft_or_layer = jax.vmap(soft_or_neuron, (0, None), 0)
 
 hard_or_layer = jax.vmap(hard_or_neuron, (0, None), 0)
 
-def symbolic_or_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_or_neuron(wi, x) for wi in w]
-
 # TODO: investigate better initialization
 def initialize_near_to_one():
     def init(key, shape, dtype):
@@ -77,13 +49,6 @@ def init(key, shape, dtype):
     return init
 
 class SoftOrLayer(nn.Module):
-    """
-    A soft-bit Or layer than transforms its inputs along the last dimension.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-        weights_init: The initializer function for the weight matrix.
-    """
     layer_size: int
     weights_init: Callable = initialize_near_to_one()
     dtype: jax.numpy.dtype = jax.numpy.float32
@@ -96,13 +61,6 @@ def __call__(self, x):
         return soft_or_layer(weights, x)
 
 class HardOrLayer(nn.Module):
-    """
-    A hard-bit Or layer that shadows the SoftAndLayer.
-    This is a convenience class to make it easier to switch between soft and hard logic.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
     layer_size: int
 
     @nn.compact
@@ -111,21 +69,14 @@ def __call__(self, x):
         weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
         return hard_or_layer(weights, x)
 
-class SymbolicOrLayer(nn.Module):
-    """A symbolic Or layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
+class SymbolicOrLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_or_layer = HardOrLayer(self.layer_size)
 
-    @nn.compact
     def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_or_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
 
 or_layer = neural_logic_net.select(
     lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
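Note on the SymbolicOrLayer rewrite: the hand-written string-building implementations deleted above are replaced by a thin wrapper that reuses HardOrLayer and defers to symbolic_generation, which traces the hard computation to a jaxpr and evaluates that jaxpr over symbolic inputs. make_symbolic_flax_jaxpr and symbolic_expression are not part of this diff; below is a rough standalone sketch of the underlying idea, using plain jax.make_jaxpr and a simplified stand-in for hard_or_neuron.

import jax
import numpy

def hard_or_neuron(w, x):
    # simplified stand-in: gate each input with its weight, then OR-reduce
    return jax.numpy.any(jax.numpy.logical_and(x, w))

w = numpy.array([True, False])
x = numpy.array([True, True])
# Tracing yields a small explicit program of primitives (and, reduce_or) that a
# symbolic interpreter can walk, emitting an expression per primitive instead of
# evaluating it on concrete values.
print(jax.make_jaxpr(hard_or_neuron)(w, x))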

tests/test_hard_and.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 from flax.training import train_state
 from jax import random
 import numpy
-import typing
 
 from neurallogic import hard_and, harden, neural_logic_net, symbolic_generation
 from tests import utils

tests/test_hard_or.py

Lines changed: 96 additions & 55 deletions
@@ -4,8 +4,10 @@
 from flax import linen as nn
 from flax.training import train_state
 from jax import random
+import numpy
 
-from neurallogic import hard_or, harden, neural_logic_net, primitives
+from neurallogic import hard_or, harden, neural_logic_net, symbolic_generation
+from tests import utils
 
 
 def test_include():
@@ -20,11 +22,8 @@ def test_include():
         [[-0.1, 1.0], 0.0]
     ]
     for input, expected in test_data:
-        assert hard_or.soft_or_include(*input) == expected
-        assert hard_or.hard_or_include(
-            *harden.harden(input)) == harden.harden(expected)
-        symbolic_output = hard_or.symbolic_or_include(*harden.harden(input))
-        assert symbolic_output == harden.harden(expected)
+        utils.check_consistency(hard_or.soft_or_include, hard_or.hard_or_include,
+                                expected, input[0], input[1])
 
 
 def test_neuron():
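Note: the per-representation assertions are replaced by utils.check_consistency, whose definition is not part of this diff. A hypothetical sketch consistent with the call sites in this file (the 0.5 hardening threshold and all names here are assumptions for illustration, not the repo's actual code):

import numpy

def harden(x):
    # stand-in for neurallogic.harden.harden: threshold soft bits at 0.5
    return numpy.asarray(x) > 0.5

def check_consistency(soft_fn, hard_fn, expected, *args):
    # the soft function on soft inputs should produce the expected soft output
    assert numpy.allclose(soft_fn(*args), expected)
    # the hard function on hardened inputs should produce the hardened expected output
    hard_args = [harden(arg) for arg in args]
    assert numpy.array_equal(numpy.asarray(hard_fn(*hard_args)), harden(expected))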
@@ -37,14 +36,14 @@ def test_neuron():
         [[0.0, 1.0], [1.0, 1.0], 1.0]
     ]
     for input, weights, expected in test_data:
-        input = jnp.array(input)
-        weights = jnp.array(weights)
-        assert jnp.allclose(hard_or.soft_or_neuron(weights, input), expected)
-        assert jnp.allclose(hard_or.hard_or_neuron(harden.harden(
-            weights), harden.harden(input)), harden.harden(expected))
-        symbolic_output = hard_or.symbolic_or_neuron(
-            harden.harden(weights.tolist()), harden.harden(input.tolist()))
-        assert jnp.array_equal(symbolic_output, harden.harden(expected))
+        def soft(weights, input):
+            return hard_or.soft_or_neuron(weights, input)
+
+        def hard(weights, input):
+            return hard_or.hard_or_neuron(weights, input)
+
+        utils.check_consistency(soft, hard, expected,
+                                jax.numpy.array(weights), jax.numpy.array(input))
 
 
 def test_layer():
@@ -59,27 +58,27 @@
             1.0, 0.0], [0.0, 0.0]], [0.0, 0.0, 0.0, 0.0]]
     ]
     for input, weights, expected in test_data:
-        input = jnp.array(input)
-        weights = jnp.array(weights)
-        expected = jnp.array(expected)
-        assert jnp.allclose(hard_or.soft_or_layer(weights, input), expected)
-        assert jnp.allclose(hard_or.hard_or_layer(harden.harden(
-            weights), harden.harden(input)), harden.harden(expected))
-        symbolic_output = hard_or.symbolic_or_layer(
-            harden.harden(weights.tolist()), harden.harden(input.tolist()))
-        assert jnp.array_equal(symbolic_output, harden.harden(expected))
+        def soft(weights, input):
+            return hard_or.soft_or_layer(weights, input)
+
+        def hard(weights, input):
+            return hard_or.hard_or_layer(weights, input)
+
+
+        utils.check_consistency(soft, hard, jax.numpy.array(expected),
+                                jax.numpy.array(weights), jax.numpy.array(input))
 
 
 def test_or():
     def test_net(type, x):
         x = hard_or.or_layer(type)(4, nn.initializers.uniform(1.0))(x)
-        x = primitives.nl_ravel(type)(x)
+        x = x.ravel()
         return x
 
     soft, hard, symbolic = neural_logic_net.net(test_net)
-    soft_weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
-    hard_weights = harden.hard_weights(soft_weights)
-    symbolic_weights = harden.symbolic_weights(soft_weights)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+
     test_data = [
         [
             [1.0, 1.0],
9998
]
10099
]
101100
for input, expected in test_data:
102-
soft_input = jnp.array(input)
103-
soft_expected = jnp.array(expected)
104-
soft_result = soft.apply(soft_weights, soft_input)
105-
assert jnp.allclose(soft_result, soft_expected)
106-
hard_input = harden.harden(soft_input)
107-
hard_expected = harden.harden(soft_expected)
108-
hard_result = hard.apply(hard_weights, hard_input)
109-
assert jnp.allclose(hard_result, hard_expected)
110-
symbolic_result = symbolic.apply(symbolic_weights, hard_input.tolist())
111-
assert jnp.array_equal(symbolic_result, hard_expected)
101+
# Check that the soft function performs as expected
102+
assert jax.numpy.allclose(soft.apply(
103+
weights, jax.numpy.array(input)), jax.numpy.array(expected))
104+
105+
# Check that the hard function performs as expected
106+
hard_input = harden.harden(jax.numpy.array(input))
107+
hard_expected = harden.harden(jax.numpy.array(expected))
108+
hard_output = hard.apply(hard_weights, hard_input)
109+
assert jax.numpy.allclose(hard_output, hard_expected)
110+
111+
# Check that the symbolic function performs as expected
112+
symbolic_output = symbolic.apply(hard_weights, hard_input)
113+
assert numpy.allclose(symbolic_output, hard_expected)
112114

113115

114116
def test_train_or():
115117
def test_net(type, x):
116118
return hard_or.or_layer(type)(4, nn.initializers.uniform(1.0))(x)
117119

118120
soft, hard, symbolic = neural_logic_net.net(test_net)
119-
soft_weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
121+
weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
122+
120123
x = [
121124
[1.0, 1.0],
122125
[1.0, 0.0],
@@ -129,30 +132,31 @@ def test_net(type, x):
         [1.0, 0.0, 0.0, 1.0],
         [0.0, 0.0, 0.0, 0.0]
     ]
-    input = jnp.array(x)
-    output = jnp.array(y)
+    input = jax.numpy.array(x)
+    output = jax.numpy.array(y)
 
     # Train the and layer
    tx = optax.sgd(0.1)
     state = train_state.TrainState.create(apply_fn=jax.vmap(
-        soft.apply, in_axes=(None, 0)), params=soft_weights, tx=tx)
+        soft.apply, in_axes=(None, 0)), params=weights, tx=tx)
     grad_fn = jax.jit(jax.value_and_grad(lambda params, x,
-                      y: jnp.mean((state.apply_fn(params, x) - y) ** 2)))
+                      y: jax.numpy.mean((state.apply_fn(params, x) - y) ** 2)))
     for epoch in range(1, 100):
         loss, grads = grad_fn(state.params, input, output)
         state = state.apply_gradients(grads=grads)
 
     # Test that the and layer (both soft and hard variants) correctly predicts y
-    soft_weights = state.params
-    hard_weights = harden.hard_weights(soft_weights)
-    symbolic_weights = harden.symbolic_weights(soft_weights)
+    weights = state.params
+    hard_weights = harden.hard_weights(weights)
+
     for input, expected in zip(x, y):
-        hard_input = harden.harden_array(harden.harden(jnp.array(input)))
-        hard_expected = harden.harden_array(harden.harden(jnp.array(expected)))
+        hard_input = harden.harden(jax.numpy.array(input))
+        hard_expected = harden.harden(jax.numpy.array(expected))
         hard_result = hard.apply(hard_weights, hard_input)
-        assert jnp.allclose(hard_result, hard_expected)
-        symbolic_result = symbolic.apply(symbolic_weights, hard_input.tolist())
-        assert jnp.array_equal(symbolic_result, hard_expected)
+        assert jax.numpy.allclose(hard_result, hard_expected)
+        symbolic_output = symbolic.apply(hard_weights, hard_input)
+        assert jax.numpy.array_equal(symbolic_output, hard_expected)
+
 
 
 def test_symbolic_or():
@@ -162,9 +166,46 @@
         return x
 
     soft, hard, symbolic = neural_logic_net.net(test_net)
-    soft_weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
-    symbolic_weights = harden.symbolic_weights(soft_weights)
+
+    # Compute soft result
+    soft_input = jax.numpy.array([0.6, 0.45])
+    weights = soft.init(random.PRNGKey(0), soft_input)
+    soft_result = soft.apply(weights, numpy.array(soft_input))
+
+    # Compute hard result
+    hard_weights = harden.hard_weights(weights)
+    hard_input = harden.harden(soft_input)
+    hard_result = hard.apply(hard_weights, numpy.array(hard_input))
+    # Check that the hard result is the same as the soft result
+    assert numpy.array_equal(harden.harden(soft_result), hard_result)
+
+    # Compute symbolic result with non-symbolic inputs
+    symbolic_output = symbolic.apply(hard_weights, hard_input)
+    # Check that the symbolic result is the same as the hard result
+    assert numpy.array_equal(symbolic_output, hard_result)
+
+    # Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
+    symbolic_input = ['True', 'False']
+    symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
+    symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
+    symbolic_output = symbolic_generation.eval_symbolic_expression(symbolic_output)
+    # Check that the symbolic result is the same as the hard result
+    assert numpy.array_equal(symbolic_output, hard_result)
+
+    # Compute symbolic result with symbolic inputs and non-symbolic weights
     symbolic_input = ['x1', 'x2']
-    symbolic_result = symbolic.apply(symbolic_weights, symbolic_input)
-    assert (symbolic_result == ['((((x1 and False) or (x2 and False)) and True) or (((x1 and False) or (x2 and False)) and True) or (((x1 and False) or (x2 and True)) and True) or (((x1 and True) or (x2 and True)) and True))', '((((x1 and False) or (x2 and False)) and False) or (((x1 and False) or (x2 and False)) and True) or (((x1 and False) or (x2 and True)) and False) or (((x1 and True) or (x2 and True)) and False))',
-            '((((x1 and False) or (x2 and False)) and False) or (((x1 and False) or (x2 and False)) and False) or (((x1 and False) or (x2 and True)) and True) or (((x1 and True) or (x2 and True)) and True))', '((((x1 and False) or (x2 and False)) and False) or (((x1 and False) or (x2 and False)) and True) or (((x1 and False) or (x2 and True)) and True) or (((x1 and True) or (x2 and True)) and True))'])
+    symbolic_output = symbolic.apply(hard_weights, symbolic_input)
+    # Check the form of the symbolic expression
+    assert numpy.array_equal(symbolic_output, ['((((False or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) or (((False or (x1 != 0) and False) or (x2 != 0) and True) != 0) and True) or (((False or (x1 != 0) and True) or (x2 != 0) and True) != 0) and True)',
+                             '((((False or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and False) or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) or (((False or (x1 != 0) and False) or (x2 != 0) and True) != 0) and False) or (((False or (x1 != 0) and True) or (x2 != 0) and True) != 0) and False)',
+                             '((((False or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and False) or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and False) or (((False or (x1 != 0) and False) or (x2 != 0) and True) != 0) and True) or (((False or (x1 != 0) and True) or (x2 != 0) and True) != 0) and True)',
+                             '((((False or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and False) or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) or (((False or (x1 != 0) and False) or (x2 != 0) and True) != 0) and True) or (((False or (x1 != 0) and True) or (x2 != 0) and True) != 0) and True)'])
+
+    # Compute symbolic result with symbolic inputs and symbolic weights
+    symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
+    # Check the form of the symbolic expression
+    assert numpy.array_equal(symbolic_output, ['((((False or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0))',
+                             '((((False or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (False != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) and (False != 0)) or (((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) and (False != 0))',
+                             '((((False or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (False != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (False != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0))',
+                             '((((False or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (False != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0)) or (((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) and (True != 0))'])

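Note: test_symbolic_or now exercises four modes: concrete hard weights and inputs; symbolic weights with symbolic-but-evaluable inputs; symbolic inputs against concrete weights; and fully symbolic weights and inputs. The asserted expression strings are ordinary Python, so binding the symbols and evaluating them yields booleans. A standalone check (not part of the commit) using the first expression from the concrete-weights case, split here only for line length:

expr = ('((((False or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) '
        'or (((False or (x1 != 0) and False) or (x2 != 0) and False) != 0) and True) '
        'or (((False or (x1 != 0) and False) or (x2 != 0) and True) != 0) and True) '
        'or (((False or (x1 != 0) and True) or (x2 != 0) and True) != 0) and True)')
# With both symbolic inputs bound to True, the expression evaluates to True.
print(eval(expr, {'x1': True, 'x2': True}))  # True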