Commit 2b30340

more tests for mask layers

1 parent a0f7af7

4 files changed: +413 -80 lines changed
neurallogic/hard_masks.py

Lines changed: 98 additions & 2 deletions
@@ -1,9 +1,14 @@
+from typing import Callable
+
 import jax
+from flax import linen as nn
+
+from neurallogic import neural_logic_net, symbolic_generation


 def soft_mask_to_true(w: float, x: float) -> float:
     """
-    w > 0.5 implies the and operation is active, else inactive
+    w > 0.5 implies the mask operation is inactive, else active

     Assumes x is in [0, 1]

@@ -17,9 +22,19 @@ def hard_mask_to_true(w, x):
     return jax.numpy.logical_or(x, jax.numpy.logical_not(w))


+soft_mask_to_true_neuron = jax.vmap(soft_mask_to_true, 0, 0)
+
+hard_mask_to_true_neuron = jax.vmap(hard_mask_to_true, 0, 0)
+
+
+soft_mask_to_true_layer = jax.vmap(soft_mask_to_true_neuron, (0, None), 0)
+
+hard_mask_to_true_layer = jax.vmap(hard_mask_to_true_neuron, (0, None), 0)
+
+
 def soft_mask_to_false(w: float, x: float) -> float:
     """
-    w > 0.5 implies the and operation is active, else inactive
+    w > 0.5 implies the mask is inactive, else active

     Assumes x is in [0, 1]

@@ -31,3 +46,84 @@ def soft_mask_to_false(w: float, x: float) -> float:

 def hard_mask_to_false(w, x):
     return jax.numpy.logical_and(x, w)
+
+
+soft_mask_to_false_neuron = jax.vmap(soft_mask_to_false, 0, 0)
+
+hard_mask_to_false_neuron = jax.vmap(hard_mask_to_false, 0, 0)
+
+
+soft_mask_to_false_layer = jax.vmap(soft_mask_to_false_neuron, (0, None), 0)
+
+hard_mask_to_false_layer = jax.vmap(hard_mask_to_false_neuron, (0, None), 0)
+
+
+class SoftMaskLayer(nn.Module):
+    mask_layer_operation: Callable
+    layer_size: int
+    weights_init: Callable = nn.initializers.uniform(1.0)
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param(
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
+        x = jax.numpy.asarray(x, self.dtype)
+        return self.mask_layer_operation(weights, x)
+
+
+class HardMaskLayer(nn.Module):
+    mask_layer_operation: Callable
+    layer_size: int
+    weights_init: Callable = nn.initializers.constant(True)
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param("bit_weights", self.weights_init, weights_shape)
+        return self.mask_layer_operation(weights, x)
+
+
+class SymbolicMaskLayer:
+    def __init__(self, mask_layer):
+        self.hard_mask_layer = mask_layer
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_mask_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
+
+mask_to_true_layer = neural_logic_net.select(
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftMaskLayer(
+        soft_mask_to_true_layer, layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_true_layer, layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicMaskLayer(
+        HardMaskLayer(hard_mask_to_true_layer, layer_size)
+    ),
+)
+
+
+mask_to_false_layer = neural_logic_net.select(
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftMaskLayer(
+        soft_mask_to_false_layer, layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_false_layer, layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicMaskLayer(
+        HardMaskLayer(hard_mask_to_false_layer, layer_size)
+    ),
)
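
Taken together, the scalar mask functions, their jax.vmap wrappers, and the select calls above give each mask three interchangeable forms: a differentiable soft layer, a boolean hard layer, and a symbolic layer. A minimal sketch (not part of the commit) of how the vmapped hard functions compose: jax.vmap(f, 0, 0) lifts the scalar op elementwise over a weight/input vector pair, and jax.vmap(neuron, (0, None), 0) then maps each row of a weight matrix over one shared input vector.

import jax.numpy as jnp

from neurallogic import hard_masks

x = jnp.array([True, False])                   # one shared input vector
w = jnp.array([[True, True], [False, False]])  # (layer_size, input_size)

# hard_mask_to_true is x OR NOT w, so a False (active) weight forces True:
print(hard_masks.hard_mask_to_true_layer(w, x))   # [[ True False] [ True  True]]

# hard_mask_to_false is x AND w, so a False (active) weight forces False:
print(hard_masks.hard_mask_to_false_layer(w, x))  # [[ True False] [False False]]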

tests/test_hard_masks.py

Lines changed: 237 additions & 1 deletion
@@ -1,4 +1,8 @@
-from neurallogic import hard_masks
+import jax
+import numpy
+from jax import random
+
+from neurallogic import hard_masks, harden, neural_logic_net
 from tests import utils


@@ -23,6 +27,122 @@ def test_mask_to_true():
     )


+def test_mask_to_true_neuron():
+    test_data = [
+        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
+        [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]],
+        [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
+        [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
+        [[0.0, 1.0], [0.0, 0.0], [1.0, 1.0]],
+        [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_true_neuron(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_true_neuron(weights, input)
+
+        utils.check_consistency(
+            soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input)
+        )
+
+
+def test_mask_to_true_layer():
+    test_data = [
+        [
+            [1.0, 0.0],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]],
+            [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 0.8]],
+        ],
+        [
+            [1.0, 0.4],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[1.0, 0.4], [1.0, 0.4], [1.0, 1.0], [1.0, 1.0]],
+        ],
+        [
+            [0.0, 1.0],
+            [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0]],
+        ],
+        [
+            [0.0, 0.0],
+            [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 0.99], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
+        ],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_true_layer(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_true_layer(weights, input)
+
+        utils.check_consistency(
+            soft,
+            hard,
+            jax.numpy.array(expected),
+            jax.numpy.array(weights),
+            jax.numpy.array(input),
+        )
+
+
+def test_mask_to_true_net():
+    def test_net(type, x):
+        x = hard_masks.mask_to_true_layer(type)(4)(x)
+        x = x.ravel()
+        return x
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+
+    test_data = [
+        [
+            [1.0, 1.0],
+            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        ],
+        [
+            [1.0, 0.0],
+            [1.0, 0.17739451, 1.0, 0.77752244, 1.0, 0.11280203, 1.0, 0.43465567],
+        ],
+        [
+            [0.0, 1.0],
+            [0.6201445, 1.0, 0.7178699, 1.0, 0.29197645, 1.0, 0.41213453, 1.0],
+        ],
+        [
+            [0.0, 0.0],
+            [
+                0.6201445,
+                0.17739451,
+                0.7178699,
+                0.77752244,
+                0.29197645,
+                0.11280203,
+                0.41213453,
+                0.43465567,
+            ],
+        ],
+    ]
+    for input, expected in test_data:
+        # Check that the soft function performs as expected
+        soft_output = soft.apply(weights, jax.numpy.array(input))
+        expected_output = jax.numpy.array(expected)
+        assert jax.numpy.allclose(soft_output, expected_output)
+
+        # Check that the hard function performs as expected
+        hard_input = harden.harden(jax.numpy.array(input))
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_output = hard.apply(hard_weights, hard_input)
+        assert jax.numpy.allclose(hard_output, hard_expected)
+
+        # Check that the symbolic function performs as expected
+        symbolic_output = symbolic.apply(hard_weights, hard_input)
+        assert numpy.allclose(symbolic_output, hard_expected)
+
+
 def test_mask_to_false():
     test_data = [
         [[1.0, 1.0], 1.0],
@@ -42,3 +162,119 @@ def test_mask_to_false():
             input[0],
             input[1],
         )
+
+
+def test_mask_to_false_neuron():
+    test_data = [
+        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
+        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+        [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
+        [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+        [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]],
+        [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_false_neuron(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_false_neuron(weights, input)
+
+        utils.check_consistency(
+            soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input)
+        )
+
+
+def test_mask_to_false_layer():
+    test_data = [
+        [
+            [1.0, 0.0],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]],
+            [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [1.0, 0.4],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[1.0, 0.39999998], [0.0, 0.39999998], [1.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [0.0, 1.0],
+            [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 1.0], [0.0, 0.8], [0.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [0.0, 0.0],
+            [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+        ],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_false_layer(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_false_layer(weights, input)
+
+        utils.check_consistency(
+            soft,
+            hard,
+            jax.numpy.array(expected),
+            jax.numpy.array(weights),
+            jax.numpy.array(input),
+        )
+
+
+def test_mask_to_false_net():
+    def test_net(type, x):
+        x = hard_masks.mask_to_false_layer(type)(4)(x)
+        x = x.ravel()
+        return x
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+
+    test_data = [
+        [
+            [1.0, 1.0],
+            [
+                0.3798555,
+                0.8226055,
+                0.28213012,
+                0.22247756,
+                0.70802355,
+                0.887198,
+                0.5878655,
+                0.56534433,
+            ],
+        ],
+        [
+            [1.0, 0.0],
+            [0.3798555, 0.0, 0.28213012, 0.0, 0.70802355, 0.0, 0.5878655, 0.0],
+        ],
+        [
+            [0.0, 1.0],
+            [0.0, 0.8226055, 0.0, 0.22247756, 0.0, 0.887198, 0.0, 0.56534433],
+        ],
+        [
+            [0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        ],
+    ]
+    for input, expected in test_data:
+        # Check that the soft function performs as expected
+        soft_output = soft.apply(weights, jax.numpy.array(input))
+        expected_output = jax.numpy.array(expected)
+        assert jax.numpy.allclose(soft_output, expected_output)
+
+        # Check that the hard function performs as expected
+        hard_input = harden.harden(jax.numpy.array(input))
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_output = hard.apply(hard_weights, hard_input)
+        assert jax.numpy.allclose(hard_output, hard_expected)
+
+        # Check that the symbolic function performs as expected
+        symbolic_output = symbolic.apply(hard_weights, hard_input)
+        assert numpy.allclose(symbolic_output, hard_expected)
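
The neuron- and layer-level tests above funnel through utils.check_consistency, whose implementation is not shown in this diff. A hypothetical sketch of the property its call sites imply (the real helper may do more): the soft function must reproduce the expected [0, 1] values, and the hard function must agree with the hardened expectation on hardened arguments.

import jax.numpy as jnp

from neurallogic import harden

# Hypothetical stand-in for utils.check_consistency, inferred from its
# call sites in tests/test_hard_masks.py; the name and exact behaviour
# here are an assumption, not the repository's actual helper.
def check_consistency_sketch(soft_fn, hard_fn, expected, *args):
    # Soft evaluation on [0, 1]-valued weights and inputs.
    assert jnp.allclose(soft_fn(*args), jnp.asarray(expected))

    # Harden every argument and the expectation to booleans, then require
    # the hard function to agree exactly.
    hard_args = [harden.harden(jnp.asarray(a)) for a in args]
    assert jnp.array_equal(hard_fn(*hard_args), harden.harden(jnp.asarray(expected)))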
