@@ -7,22 +7,9 @@
 import typing

 from neurallogic import hard_and, harden, neural_logic_net, symbolic_generation
+from tests import utils


-def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
-    # Check that the soft function performs as expected
-    assert numpy.allclose(soft(*args), expected)
-
-    # Check that the hard function performs as expected
-    hard_args = harden.harden(*args)
-    hard_expected = harden.harden(expected)
-    assert numpy.allclose(hard(*hard_args), hard_expected)
-
-    # Check that the jaxpr performs as expected
-    symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args)
-    assert numpy.allclose(symbolic_generation.eval_symbolic(
-        symbolic_f, *hard_args), hard_expected)
-

 def test_include():
     test_data = [
@@ -36,7 +23,7 @@ def test_include():
         [[-0.1, 1.0], 1.0]
     ]
     for input, expected in test_data:
-        check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
+        utils.check_consistency(hard_and.soft_and_include, hard_and.hard_and_include,
                           expected, input[0], input[1])


@@ -56,7 +43,7 @@ def soft(weights, input):
     def hard(weights, input):
         return hard_and.hard_and_neuron(weights, input)

-    check_consistency(soft, hard, expected,
+    utils.check_consistency(soft, hard, expected,
                       jax.numpy.array(weights), jax.numpy.array(input))


@@ -81,7 +68,7 @@ def hard(weights, input):
         return hard_and.hard_and_layer(weights, input)


-    check_consistency(soft, hard, jax.numpy.array(expected),
+    utils.check_consistency(soft, hard, jax.numpy.array(expected),
                       jax.numpy.array(weights), jax.numpy.array(input))

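Note: this diff only shows the call sites; the helper itself moves out of this file, and its new home is not shown here. A minimal sketch of what tests/utils.py presumably contains after the move, assuming check_consistency is relocated essentially verbatim (the numpy, typing, harden, and symbolic_generation imports are inferred from the function body, not confirmed by this diff):

import typing

import numpy

from neurallogic import harden, symbolic_generation


# Assumed contents of tests/utils.py: the helper from this diff, moved as-is.
def check_consistency(soft: typing.Callable, hard: typing.Callable, expected, *args):
    # Check that the soft function performs as expected
    assert numpy.allclose(soft(*args), expected)

    # Check that the hard function performs as expected on hardened inputs
    hard_args = harden.harden(*args)
    hard_expected = harden.harden(expected)
    assert numpy.allclose(hard(*hard_args), hard_expected)

    # Check that the symbolically generated jaxpr performs as expected
    symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args)
    assert numpy.allclose(
        symbolic_generation.eval_symbolic(symbolic_f, *hard_args), hard_expected)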