import jax
from flax import linen as nn

-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation


def soft_not(w: float, x: float) -> float:
@@ -18,57 +18,24 @@ def soft_not(w: float, x: float) -> float:
    return 1.0 - w + x * (2.0 * w - 1.0)


-@jax.jit
def hard_not(w: bool, x: bool) -> bool:
-    return ~(x ^ w)
-
-
-def symbolic_not(w, x):
-    expression = f"(not({x} ^ {w}))"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+    return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))


soft_not_neuron = jax.vmap(soft_not, 0, 0)

hard_not_neuron = jax.vmap(hard_not, 0, 0)


-def symbolic_not_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not(wi, xi) for wi, xi in zip(w, x)]
-

soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)

hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)


-def symbolic_not_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not_neuron(wi, x) for wi in w]


class SoftNotLayer(nn.Module):
-    """
-    A soft-bit NOT layer than transforms its inputs along the last dimension.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-        weights_init: The initializer function for the weight matrix.
-    """
    layer_size: int
    weights_init: Callable = nn.initializers.uniform(1.0)
    dtype: jax.numpy.dtype = jax.numpy.float32
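For context (an illustrative aside, not part of the change): soft_not interpolates between pass-through and negation, returning x when w = 1.0, 1.0 - x when w = 0.0, and a constant 0.5 when w = 0.5. The new hard_not matches it at the Boolean corners, since not(x xor w) is x when w is True and not x when w is False; using jax.numpy.logical_not and jax.numpy.logical_xor also behaves uniformly on Python bools and JAX arrays, whereas Python's bitwise ~True is -2. The vmap wrappers lift the scalar function element-wise over a weight/input vector (the neuron) and then over the rows of a weight matrix with the input broadcast (the layer). A quick spot-check, assuming the definitions above are in scope:

    import jax.numpy as jnp

    # Scalar soft semantics (operands chosen to be exact in binary floating point).
    assert soft_not(1.0, 0.75) == 0.75    # w = 1.0: pass x through
    assert soft_not(0.0, 0.75) == 0.25    # w = 0.0: negate x
    assert soft_not(0.5, 0.75) == 0.5     # w = 0.5: maximally uncertain weight

    # Hard semantics agree at the Boolean corners: not(x xor w).
    assert bool(hard_not(True, False)) is False    # w = True passes x through
    assert bool(hard_not(False, False)) is True    # w = False negates x

    # The layer maps a (layer_size, n) weight matrix against an (n,) input
    # and produces a (layer_size, n) output.
    weights = jnp.zeros((4, 3))
    x = jnp.ones(3)
    assert soft_not_layer(weights, x).shape == (4, 3)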
@@ -83,13 +50,6 @@ def __call__(self, x):


class HardNotLayer(nn.Module):
-    """
-    A hard-bit NOT layer that shadows the SoftNotLayer.
-    This is a convenience class to make it easier to switch between soft and hard logic.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
    layer_size: int

    @nn.compact
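These Flax modules are used through the usual init/apply calls. A minimal sketch, assuming SoftNotLayer's elided __call__ mirrors the vmapped soft_not_layer above, so its 'weights' parameter has shape (layer_size, input_width) and the output has the same shape:

    import jax
    import jax.numpy as jnp

    layer = SoftNotLayer(layer_size=4)
    x = jnp.array([0.9, 0.1, 0.6])                 # a soft-bit input vector
    params = layer.init(jax.random.PRNGKey(0), x)  # standard Flax initialization
    y = layer.apply(params, x)                     # expected shape: (4, 3)

HardNotLayer follows the same pattern, but with Boolean weights and inputs as in hard_not's signature.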
@@ -100,22 +60,14 @@ def __call__(self, x):
        return hard_not_layer(weights, x)


-class SymbolicNotLayer(nn.Module):
-    """A symbolic NOT layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
+class SymbolicNotLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_not_layer = HardNotLayer(self.layer_size)

-    @nn.compact
    def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_not_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_not_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)


not_layer = neural_logic_net.select(
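The reworked SymbolicNotLayer no longer hand-builds expression strings and eval()s them; it wraps a HardNotLayer, traces it into a jaxpr with symbolic_generation.make_symbolic_flax_jaxpr, and passes that jaxpr to symbolic_generation.symbolic_expression. Deriving the symbolic form from the hard implementation also removes the duplicated symbolic_not / symbolic_not_neuron / symbolic_not_layer chain and its TODOs about matching vmap's generality. A rough usage sketch, with the caveat that the accepted input types and the exact form of the returned expression are whatever symbolic_generation defines, and that the truncated neural_logic_net.select(...) call above is left as it appears in the source:

    import jax.numpy as jnp

    symbolic_layer = SymbolicNotLayer(layer_size=4)
    x = jnp.array([True, False, True])
    # Traces the wrapped HardNotLayer on x and returns a symbolic expression
    # generated from the resulting jaxpr.
    expression = symbolic_layer(x)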