@@ -119,8 +119,6 @@ def test_net(type, x):
119
119
soft , hard , symbolic = neural_logic_net .net (test_net )
120
120
weights = soft .init (random .PRNGKey (0 ), [0.0 , 0.0 ])
121
121
hard_weights = harden .hard_weights (weights )
122
- print (f'weights: { weights } ' )
123
- print (f'hard_weights: { hard_weights } ' )
124
122
125
123
test_data = [
126
124
[
@@ -133,63 +131,36 @@ def test_net(type, x):
133
131
[
134
132
[0.6 , 0.0 ],
135
133
[
136
- [
137
- 0.9469013 ,
138
- 0.320184 ,
139
- 0.3194083 ,
140
- ],
141
- [
142
- 0.58414006 ,
143
- 0.7815013 ,
144
- 0.04193211 ,
145
- ],
134
+ [0.78442293 , 0.7857669 , 0.3154459 ],
135
+ [0.0 , 0.0 , 0.0 ],
146
136
],
147
137
],
148
138
[
149
139
[0.1 , 0.9 ],
150
140
[
151
- [
152
- 0.05309868 ,
153
- 0.679816 ,
154
- 0.6805917 ,
155
- ],
156
- [
157
- 0.41585994 ,
158
- 0.2184987 ,
159
- 0.9580679 ,
160
- ],
141
+ [0.5149515 , 0.51797545 , 0.05257431 ],
142
+ [0.69679934 , 0.629154 , 0.84623945 ],
161
143
],
162
144
],
163
145
[
164
146
[0.4 , 0.6 ],
165
147
[
166
- [
167
- 0.05309868 ,
168
- 0.320184 ,
169
- 0.6805917 ,
170
- ],
171
- [
172
- 0.58414006 ,
173
- 0.2184987 ,
174
- 0.04193211 ,
175
- ],
148
+ [0.6766343 , 0.67865026 , 0.21029726 ],
149
+ [0.35924158 , 0.34675142 , 0.4445637 ],
176
150
],
177
151
],
178
152
]
179
153
for input , expected in test_data :
180
154
# Check that the soft function performs as expected
181
155
soft_output = soft .apply (weights , jax .numpy .array (input ))
182
156
soft_expected = jax .numpy .array (expected )
183
- print (f'soft_output: { soft_output } \n soft_expected: { soft_expected } ' )
184
157
assert jax .numpy .allclose (soft_output , soft_expected )
185
158
186
159
# Check that the hard function performs as expected
187
160
hard_expected = harden .harden (jax .numpy .array (expected ))
188
161
hard_output = hard .apply (hard_weights , jax .numpy .array (input ))
189
- print (f'hard_output: { hard_output } \n hard_expected: { hard_expected } ' )
190
162
assert jax .numpy .allclose (hard_output , hard_expected )
191
163
192
164
# Check that the symbolic function performs as expected
193
165
symbolic_output = symbolic .apply (hard_weights , jax .numpy .array (input ))
194
- print (f'symbolic_output: { symbolic_output } \n hard_expected: { hard_expected } ' )
195
166
assert numpy .allclose (symbolic_output , hard_expected )
0 commit comments