Skip to content

Commit 27aff8c

Browse files
committed
refactor code
1 parent f6f65c2 commit 27aff8c

9 files changed

+196
-186
lines changed

neurallogic/harden.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import jax
33
import numpy
44
from plum import dispatch
5-
from neurallogic import symbolic_primitives
5+
6+
from neurallogic import map_at_elements
67

78

89
def harden_float(x: float) -> bool:
@@ -21,7 +22,7 @@ def harden(x: float):
2122

2223
@dispatch
2324
def harden(x: list):
24-
return symbolic_primitives.map_at_elements(x, harden_float)
25+
return map_at_elements.map_at_elements(x, harden_float)
2526

2627

2728
@dispatch
@@ -39,7 +40,7 @@ def harden(x: dict):
3940
# Only harden parameters that explicitly represent bits
4041
def conditional_harden(k, v):
4142
if k.startswith("bit_"):
42-
return symbolic_primitives.map_at_elements(v, harden)
43+
return map_at_elements.map_at_elements(v, harden)
4344
elif isinstance(v, dict) or isinstance(v, flax.core.FrozenDict) or isinstance(v, list):
4445
return harden(v)
4546
return v

neurallogic/map_at_elements.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import typing
2+
3+
import jax
4+
import numpy
5+
from plum import dispatch
6+
7+
8+
@dispatch
9+
def map_at_elements(x: str, func: typing.Callable):
10+
return func(x)
11+
12+
13+
@dispatch
14+
def map_at_elements(x: bool, func: typing.Callable):
15+
return func(x)
16+
17+
18+
@dispatch
19+
def map_at_elements(x: numpy.bool_, func: typing.Callable):
20+
return func(x)
21+
22+
23+
@dispatch
24+
def map_at_elements(x: float, func: typing.Callable):
25+
return func(x)
26+
27+
28+
@dispatch
29+
def map_at_elements(x: numpy.float32, func: typing.Callable):
30+
return func(x)
31+
32+
33+
@dispatch
34+
def map_at_elements(x: list, func: typing.Callable):
35+
return [map_at_elements(item, func) for item in x]
36+
37+
38+
@dispatch
39+
def map_at_elements(x: numpy.ndarray, func: typing.Callable):
40+
return numpy.array([map_at_elements(item, func) for item in x], dtype=object)
41+
42+
43+
@dispatch
44+
def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable):
45+
if x.ndim == 0:
46+
return func(x.item())
47+
return jax.numpy.array([map_at_elements(item, func) for item in x])
48+
49+
50+
@dispatch
51+
def map_at_elements(x: dict, func: typing.Callable):
52+
return {k: map_at_elements(v, func) for k, v in x.items()}
53+
54+
55+
@dispatch
56+
def map_at_elements(x: tuple, func: typing.Callable):
57+
return tuple(map_at_elements(list(x), func))
58+

neurallogic/real_encoder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from neurallogic import neural_logic_net, symbolic_generation
77

8+
89
def soft_real_encoder(t: float, x: float) -> float:
910
eps = 0.0000001
1011
# x should be in [0, 1]
@@ -44,7 +45,8 @@ class SoftRealEncoderLayer(nn.Module):
4445
@nn.compact
4546
def __call__(self, x):
4647
thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
47-
thresholds = self.param("thresholds", self.thresholds_init, thresholds_shape, self.dtype)
48+
thresholds = self.param(
49+
"thresholds", self.thresholds_init, thresholds_shape, self.dtype)
4850
x = jax.numpy.asarray(x, self.dtype)
4951
return soft_real_encoder_layer(thresholds, x)
5052

@@ -55,7 +57,8 @@ class HardRealEncoderLayer(nn.Module):
5557
@nn.compact
5658
def __call__(self, x):
5759
thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real)
58-
thresholds = self.param("thresholds", nn.initializers.constant(0.0), thresholds_shape)
60+
thresholds = self.param(
61+
"thresholds", nn.initializers.constant(0.0), thresholds_shape)
5962
return hard_real_encoder_layer(thresholds, x)
6063

6164

neurallogic/symbolic_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from jax._src.util import safe_map
88
from plum import dispatch
99

10-
from neurallogic import symbolic_primitives
10+
from neurallogic import symbolic_primitives, map_at_elements
1111

1212
# Imports required for evaluating symbolic expressions with eval()
1313
import jax._src.lax_reference as lax_reference
@@ -74,7 +74,7 @@ def convert_to_numeric_params(flax_layer, param_names: str):
7474
if isinstance(actual_weights, list) or (
7575
isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object
7676
):
77-
numeric_weights = symbolic_primitives.map_at_elements(
77+
numeric_weights = map_at_elements.map_at_elements(
7878
actual_weights, lambda x: 0
7979
)
8080
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
@@ -87,7 +87,7 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
8787
flax_layer, thresholds = convert_to_numeric_params(flax_layer, 'thresholds')
8888
# Convert input to dummy numeric input (if needed)
8989
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
90-
x = symbolic_primitives.map_at_elements(x, lambda x: 0)
90+
x = map_at_elements.map_at_elements(x, lambda x: 0)
9191
x = numpy.asarray(x, dtype=numpy.int32)
9292
# Make the jaxpr that corresponds to the flax layer
9393
jaxpr = make_symbolic_jaxpr(flax_layer, x)

neurallogic/symbolic_operator.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import jax
2+
import numpy
3+
from plum import dispatch
4+
5+
6+
@dispatch
7+
def symbolic_operator(operator: str, x: str) -> str:
8+
return f'{operator}({x})'.replace('\'', '')
9+
10+
11+
@dispatch
12+
def symbolic_operator(operator: str, x: float, y: str):
13+
return symbolic_operator(operator, str(x), y)
14+
15+
16+
@dispatch
17+
def symbolic_operator(operator: str, x: str, y: float):
18+
return symbolic_operator(operator, x, str(y))
19+
20+
21+
@dispatch
22+
def symbolic_operator(operator: str, x: float, y: numpy.ndarray):
23+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
24+
25+
26+
@dispatch
27+
def symbolic_operator(operator: str, x: numpy.ndarray, y: numpy.ndarray):
28+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
29+
30+
31+
@dispatch
32+
def symbolic_operator(operator: str, x: str, y: str):
33+
return f'{operator}({x}, {y})'.replace('\'', '')
34+
35+
36+
@dispatch
37+
def symbolic_operator(operator: str, x: numpy.ndarray, y: float):
38+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
39+
40+
41+
@dispatch
42+
def symbolic_operator(operator: str, x: list, y: float):
43+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
44+
45+
46+
@dispatch
47+
def symbolic_operator(operator: str, x: list, y: list):
48+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
49+
50+
51+
@dispatch
52+
def symbolic_operator(operator: str, x: bool, y: str):
53+
return symbolic_operator(operator, str(x), y)
54+
55+
56+
@dispatch
57+
def symbolic_operator(operator: str, x: str, y: numpy.ndarray):
58+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
59+
60+
61+
@dispatch
62+
def symbolic_operator(operator: str, x: str, y: int):
63+
return symbolic_operator(operator, x, str(y))
64+
65+
66+
@dispatch
67+
def symbolic_operator(operator: str, x: list, y: numpy.ndarray):
68+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
69+
70+
71+
@dispatch
72+
def symbolic_operator(operator: str, x: numpy.ndarray, y: jax.numpy.ndarray):
73+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y)
74+
75+
76+
@dispatch
77+
def symbolic_operator(operator: str, x: numpy.ndarray):
78+
return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x)
79+
80+
81+
@dispatch
82+
def symbolic_operator(operator: str, x: list):
83+
return symbolic_operator(operator, numpy.array(x))
84+

0 commit comments

Comments
 (0)