- from functools import reduce
- from typing import Callable
+ from typing import Any

+ import numpy
import jax
from flax import linen as nn
+ from typing import Callable
+
- from neurallogic import neural_logic_net
+ from neurallogic import neural_logic_net, symbolic_generation


def soft_and_include(w: float, x: float) -> float:
    """
    w > 0.5 implies the AND operation is active, else inactive

    Assumes x is in [0, 1]
-
+
    Corresponding hard logic: x OR !w
    """
    w = jax.numpy.clip(w, 0.0, 1.0)
    return jax.numpy.maximum(x, 1.0 - w)

- @jax.jit
- def hard_and_include(w: bool, x: bool) -> bool:
-     return x | ~w

- def symbolic_and_include(w, x):
-     expression = f"({x} or not({w}))"
-     # Check if w is of type bool
-     if isinstance(w, bool) and isinstance(x, bool):
-         # We know the value of w and x, so we can evaluate the expression
-         return eval(expression)
-     # We don't know the value of w or x, so we return the expression
-     return expression
+
+ def hard_and_include(w, x):
+     return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
+
+

def soft_and_neuron(w, x):
    x = jax.vmap(soft_and_include, 0, 0)(w, x)
    return jax.numpy.min(x)

+
def hard_and_neuron(w, x):
    x = jax.vmap(hard_and_include, 0, 0)(w, x)
    return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])

- def symbolic_and_neuron(w, x):
-     # TODO: ensure that this implementation has the same generality over tensors as vmap
-     if not isinstance(w, list):
-         raise TypeError(f"Input {x} should be a list")
-     if not isinstance(x, list):
-         raise TypeError(f"Input {x} should be a list")
-     y = [symbolic_and_include(wi, xi) for wi, xi in zip(w, x)]
-     expression = "(" + str(reduce(lambda a, b: f"{a} and {b}", y)) + ")"
-     if all(isinstance(yi, bool) for yi in y):
-         # We know the value of all yis, so we can evaluate the expression
-         return eval(expression)
-     return expression

soft_and_layer = jax.vmap(soft_and_neuron, (0, None), 0)

hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)

- def symbolic_and_layer(w, x):
-     # TODO: ensure that this implementation has the same generality over tensors as vmap
-     if not isinstance(w, list):
-         raise TypeError(f"Input {x} should be a list")
-     if not isinstance(x, list):
-         raise TypeError(f"Input {x} should be a list")
-     return [symbolic_and_neuron(wi, x) for wi in w]

- # TODO: investigate better initialization
def initialize_near_to_zero():
+     # TODO: investigate better initialization
    def init(key, shape, dtype):
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        # Sample from standard normal distribution (zero mean, unit variance)
@@ -76,6 +54,7 @@ def init(key, shape, dtype):
        return x
    return init

+
class SoftAndLayer(nn.Module):
    """
    A soft-bit AND layer that transforms its inputs along the last dimension.
@@ -91,10 +70,12 @@ class SoftAndLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-         weights = self.param('weights', self.weights_init, weights_shape, self.dtype)
+         weights = self.param('weights', self.weights_init,
+                              weights_shape, self.dtype)
        x = jax.numpy.asarray(x, self.dtype)
        return soft_and_layer(weights, x)

+
class HardAndLayer(nn.Module):
    """
    A hard-bit And layer that shadows the SoftAndLayer.
@@ -108,26 +89,22 @@ class HardAndLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-         weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
+         weights = self.param(
+             'weights', nn.initializers.constant(0.0), weights_shape)
        return hard_and_layer(weights, x)

- class SymbolicAndLayer(nn.Module):
-     """A symbolic And layer than transforms its inputs along the last dimension.
-     Attributes:
-         layer_size: The number of neurons in the layer.
-     """
-     layer_size: int

-     @nn.compact
+ class SymbolicAndLayer:
+     def __init__(self, layer_size):
+         self.layer_size = layer_size
+         self.hard_and_layer = HardAndLayer(self.layer_size)
+
    def __call__(self, x):
-         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-         weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
-         weights = weights.tolist()
-         if not isinstance(x, list):
-             raise TypeError(f"Input {x} should be a list")
-         return symbolic_and_layer(weights, x)
+         jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_and_layer, x)
+         return symbolic_generation.symbolic_expression(jaxpr, x)
+

and_layer = neural_logic_net.select(
-     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
+     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
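
For orientation, here is a minimal usage sketch (not part of the commit). It illustrates the soft/hard correspondence described in soft_and_include's docstring and the standard Flax init/apply pattern for the two nn.Module layers. It assumes the definitions above are in scope (e.g., run inside the same module) and that SoftAndLayer takes (layer_size, weights_init, dtype) positionally, as the select() lambdas suggest; the input values and layer size are illustrative only.

import jax

# Soft/hard correspondence of the elementwise include (hard logic: x OR !w).
print(soft_and_include(1.0, 0.2))      # ~0.2: w > 0.5, so the input participates
print(soft_and_include(0.0, 0.2))      # 1.0:  w < 0.5, so the input is ignored
print(hard_and_include(True, False))   # False
print(hard_and_include(False, False))  # True

# Flax init/apply for the layer modules (layer_size=2, four inputs; values made up).
x_soft = jax.numpy.array([0.9, 0.1, 0.8, 0.3])
soft = SoftAndLayer(2, initialize_near_to_zero(), jax.numpy.float32)
params = soft.init(jax.random.PRNGKey(0), x_soft)   # creates 'weights' of shape (2, 4)
y_soft = soft.apply(params, x_soft)                 # two soft-bit AND outputs in [0, 1]

x_hard = jax.numpy.array([True, False, True, True])
hard = HardAndLayer(2)
hard_params = hard.init(jax.random.PRNGKey(0), x_hard)
y_hard = hard.apply(hard_params, x_hard)            # two boolean AND outputs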