@@ -1,4 +1,8 @@
-from neurallogic import hard_masks
+import jax
+import numpy
+from jax import random
+
+from neurallogic import hard_masks, harden, neural_logic_net
 from tests import utils

@@ -23,6 +27,122 @@ def test_mask_to_true():
         )


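+# Each row of test_data is [input, weights, expected]. The truth table pins
+# down the hard mask-to-true neuron as elementwise `x OR NOT w`: a low weight
+# masks the input to True, a high weight passes it through. check_consistency
+# asserts that the soft and hard implementations both match `expected`.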
+def test_mask_to_true_neuron():
+    test_data = [
+        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
+        [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]],
+        [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
+        [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
+        [[0.0, 1.0], [0.0, 0.0], [1.0, 1.0]],
+        [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_true_neuron(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_true_neuron(weights, input)
+
+        utils.check_consistency(
+            soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input)
+        )
+
+
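+# The layer applies one mask-to-true neuron per weight row, so a (4, 2)
+# weight matrix maps a 2-element input to a (4, 2) output. Fractional
+# weights (0.2, 0.8, 0.01) exercise the soft relaxation between the two
+# boolean extremes.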
+def test_mask_to_true_layer():
+    test_data = [
+        [
+            [1.0, 0.0],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]],
+            [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 0.8]],
+        ],
+        [
+            [1.0, 0.4],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[1.0, 0.4], [1.0, 0.4], [1.0, 1.0], [1.0, 1.0]],
+        ],
+        [
+            [0.0, 1.0],
+            [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0]],
+        ],
+        [
+            [0.0, 0.0],
+            [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 0.99], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
+        ],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_true_layer(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_true_layer(weights, input)
+
+        utils.check_consistency(
+            soft,
+            hard,
+            jax.numpy.array(expected),
+            jax.numpy.array(weights),
+            jax.numpy.array(input),
+        )
+
+
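+# End-to-end check: neural_logic_net.net builds soft, hard and symbolic
+# versions of the same two-input, four-neuron network. The expected soft
+# outputs correspond to weights initialised from random.PRNGKey(0); the hard
+# net run on hardened inputs must reproduce the hardened expected values,
+# and the symbolic form must match the hard one.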
+def test_mask_to_true_net():
+    def test_net(type, x):
+        x = hard_masks.mask_to_true_layer(type)(4)(x)
+        x = x.ravel()
+        return x
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+
+    test_data = [
+        [
+            [1.0, 1.0],
+            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        ],
+        [
+            [1.0, 0.0],
+            [1.0, 0.17739451, 1.0, 0.77752244, 1.0, 0.11280203, 1.0, 0.43465567],
+        ],
+        [
+            [0.0, 1.0],
+            [0.6201445, 1.0, 0.7178699, 1.0, 0.29197645, 1.0, 0.41213453, 1.0],
+        ],
+        [
+            [0.0, 0.0],
+            [
+                0.6201445,
+                0.17739451,
+                0.7178699,
+                0.77752244,
+                0.29197645,
+                0.11280203,
+                0.41213453,
+                0.43465567,
+            ],
+        ],
+    ]
+    for input, expected in test_data:
+        # Check that the soft function performs as expected
+        soft_output = soft.apply(weights, jax.numpy.array(input))
+        expected_output = jax.numpy.array(expected)
+        assert jax.numpy.allclose(soft_output, expected_output)
+
+        # Check that the hard function performs as expected
+        hard_input = harden.harden(jax.numpy.array(input))
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_output = hard.apply(hard_weights, hard_input)
+        assert jax.numpy.allclose(hard_output, hard_expected)
+
+        # Check that the symbolic function performs as expected
+        symbolic_output = symbolic.apply(hard_weights, hard_input)
+        assert numpy.allclose(symbolic_output, hard_expected)
+
+
 def test_mask_to_false():
     test_data = [
         [[1.0, 1.0], 1.0],
@@ -42,3 +162,119 @@ def test_mask_to_false():
             input[0],
             input[1],
         )
+
+
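+# Dual of the mask-to-true neuron: the truth table below pins down the hard
+# mask-to-false neuron as elementwise `x AND w`, so a low weight masks the
+# input to False and a high weight passes it through.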
+def test_mask_to_false_neuron():
+    test_data = [
+        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
+        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+        [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
+        [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+        [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]],
+        [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_false_neuron(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_false_neuron(weights, input)
+
+        utils.check_consistency(
+            soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input)
+        )
+
+
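+# As above, but for mask-to-false: each of the four weight rows masks the
+# 2-element input independently, and fractional weights probe the soft
+# relaxation (note the float32 rounding in 0.39999998).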
+def test_mask_to_false_layer():
+    test_data = [
+        [
+            [1.0, 0.0],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]],
+            [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [1.0, 0.4],
+            [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[1.0, 0.39999998], [0.0, 0.39999998], [1.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [0.0, 1.0],
+            [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 1.0], [0.0, 0.8], [0.0, 0.0], [0.0, 0.0]],
+        ],
+        [
+            [0.0, 0.0],
+            [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
+            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+        ],
+    ]
+    for input, weights, expected in test_data:
+
+        def soft(weights, input):
+            return hard_masks.soft_mask_to_false_layer(weights, input)
+
+        def hard(weights, input):
+            return hard_masks.hard_mask_to_false_layer(weights, input)
+
+        utils.check_consistency(
+            soft,
+            hard,
+            jax.numpy.array(expected),
+            jax.numpy.array(weights),
+            jax.numpy.array(input),
+        )
+
+
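+# End-to-end check for mask-to-false, mirroring test_mask_to_true_net: the
+# soft, hard and symbolic versions of the same network must agree after
+# hardening.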
+def test_mask_to_false_net():
+    def test_net(type, x):
+        x = hard_masks.mask_to_false_layer(type)(4)(x)
+        x = x.ravel()
+        return x
+
+    soft, hard, symbolic = neural_logic_net.net(test_net)
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    hard_weights = harden.hard_weights(weights)
+
+    test_data = [
+        [
+            [1.0, 1.0],
+            [
+                0.3798555,
+                0.8226055,
+                0.28213012,
+                0.22247756,
+                0.70802355,
+                0.887198,
+                0.5878655,
+                0.56534433,
+            ],
+        ],
+        [
+            [1.0, 0.0],
+            [0.3798555, 0.0, 0.28213012, 0.0, 0.70802355, 0.0, 0.5878655, 0.0],
+        ],
+        [
+            [0.0, 1.0],
+            [0.0, 0.8226055, 0.0, 0.22247756, 0.0, 0.887198, 0.0, 0.56534433],
+        ],
+        [
+            [0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        ],
+    ]
+    for input, expected in test_data:
+        # Check that the soft function performs as expected
+        soft_output = soft.apply(weights, jax.numpy.array(input))
+        expected_output = jax.numpy.array(expected)
+        assert jax.numpy.allclose(soft_output, expected_output)
+
+        # Check that the hard function performs as expected
+        hard_input = harden.harden(jax.numpy.array(input))
+        hard_expected = harden.harden(jax.numpy.array(expected))
+        hard_output = hard.apply(hard_weights, hard_input)
+        assert jax.numpy.allclose(hard_output, hard_expected)
+
+        # Check that the symbolic function performs as expected
+        symbolic_output = symbolic.apply(hard_weights, hard_input)
+        assert numpy.allclose(symbolic_output, hard_expected)