Skip to content

Commit f8da478

Browse files
committed
refactor some utils
1 parent 4beee42 commit f8da478

12 files changed

+113
-208
lines changed

neurallogic/symbolic_generation.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import sys
21
import typing
32
from typing import Any, Mapping
43

5-
import flax
64
import jax
75
import numpy
86
from jax import core
@@ -209,70 +207,6 @@ def eval_jaxpr_impl(jaxpr):
209207
return safe_map(symbolic_read, jaxpr.outvars)[0]
210208

211209

212-
def to_string(x):
213-
return str(x)
214-
215-
# TODO: use union types to consolidate these functions
216-
@dispatch
217-
def make_symbolic(x: dict):
218-
return symbolic_primitives.map_at_elements(
219-
x, to_string
220-
)
221-
222-
223-
@dispatch
224-
def make_symbolic(x: list):
225-
return symbolic_primitives.map_at_elements(
226-
x, to_string
227-
)
228-
229-
230-
@dispatch
231-
def make_symbolic(x: numpy.ndarray):
232-
return symbolic_primitives.map_at_elements(
233-
x, to_string
234-
)
235-
236-
237-
@dispatch
238-
def make_symbolic(x: jax.numpy.ndarray):
239-
return symbolic_primitives.map_at_elements(
240-
convert_jax_to_numpy_arrays(x), to_string
241-
)
242-
243-
244-
@dispatch
245-
def make_symbolic(x: bool):
246-
return to_string(x)
247-
248-
249-
@dispatch
250-
def make_symbolic(x: str):
251-
return to_string(x)
252-
253-
254-
@dispatch
255-
def convert_jax_to_numpy_arrays(x: jax.numpy.ndarray):
256-
return numpy.asarray(x)
257-
258-
259-
@dispatch
260-
def convert_jax_to_numpy_arrays(x: dict):
261-
return {k: convert_jax_to_numpy_arrays(v) for k, v in x.items()}
262-
263-
264-
@dispatch
265-
def make_symbolic(x: flax.core.FrozenDict):
266-
x = convert_jax_to_numpy_arrays(x.unfreeze())
267-
return flax.core.FrozenDict(make_symbolic(x))
268-
269-
270-
@dispatch
271-
def make_symbolic(*args):
272-
return tuple([make_symbolic(arg) for arg in args])
273-
274-
275-
@dispatch
276210
def make_symbolic_jaxpr(func: typing.Callable, *args):
277211
return jax.make_jaxpr(lambda *args: func(*args))(*args)
278212

tests/test_hard_and.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import jax
2+
import numpy
23
import optax
34
from flax import linen as nn
45
from flax.training import train_state
56
from jax import random
6-
import numpy
77

88
from neurallogic import hard_and, harden, neural_logic_net, symbolic_generation
99
from tests import utils
@@ -185,7 +185,7 @@ def test_net(type, x):
185185

186186
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
187187
symbolic_input = ['True', 'False']
188-
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
188+
symbolic_weights = utils.make_symbolic(hard_weights)
189189
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
190190
symbolic_output = symbolic_generation.eval_symbolic_expression(
191191
symbolic_output)

tests/test_hard_not.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import jax
22
import jax.numpy as jnp
3+
import numpy
34
import optax
45
from flax.training import train_state
56
from jax import random
6-
import numpy
77

88
from neurallogic import hard_not, harden, neural_logic_net, symbolic_generation
99
from tests import utils
@@ -250,7 +250,7 @@ def test_net(type, x):
250250

251251
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
252252
symbolic_input = ["True", "False"]
253-
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
253+
symbolic_weights = utils.make_symbolic(hard_weights)
254254
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
255255
symbolic_output = symbolic_generation.eval_symbolic_expression(
256256
symbolic_output)

tests/test_hard_or.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import jax
22
import jax.numpy as jnp
3+
import numpy
34
import optax
45
from flax import linen as nn
56
from flax.training import train_state
67
from jax import random
7-
import numpy
88

99
from neurallogic import hard_or, harden, neural_logic_net, symbolic_generation
1010
from tests import utils
@@ -184,7 +184,7 @@ def test_net(type, x):
184184

185185
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
186186
symbolic_input = ['True', 'False']
187-
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
187+
symbolic_weights = utils.make_symbolic(hard_weights)
188188
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
189189
symbolic_output = symbolic_generation.eval_symbolic_expression(
190190
symbolic_output)

tests/test_harden.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import flax
12
import jax.numpy as jnp
3+
24
from neurallogic import harden
3-
import flax
45

56

67
def test_harden_float():

tests/test_harden_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import jax.numpy as jnp
22

3-
from neurallogic import harden_layer, harden
3+
from neurallogic import harden, harden_layer
44

55

66
def test_harden_layer():

tests/test_mnist.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
from tqdm import tqdm
2-
from matplotlib import pyplot as plt
3-
import tensorflow as tf
4-
import tensorflow_datasets as tfds
51
import jax
62
import jax.numpy as jnp
3+
import ml_collections
74
import numpy as np
5+
import optax
6+
import tensorflow as tf
7+
import tensorflow_datasets as tfds
88
from flax import linen as nn
99
from flax.metrics import tensorboard
1010
from flax.training import train_state
11-
import ml_collections
12-
from neurallogic import hard_not, hard_or, harden, harden_layer, neural_logic_net
13-
import optax
11+
from matplotlib import pyplot as plt
12+
from tqdm import tqdm
1413

14+
from neurallogic import (hard_not, hard_or, harden, harden_layer,
15+
neural_logic_net)
1516

1617
"""
1718
MNIST test.

tests/test_network.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from neurallogic import (hard_and, hard_not, hard_or, harden, harden_layer,
1111
neural_logic_net, symbolic_generation)
12+
from tests import utils
1213

1314
config.update("jax_enable_x64", True)
1415

@@ -70,7 +71,7 @@ def test_net(type, x):
7071
hard_result = hard.apply(hard_weights, hard_input)
7172
assert jnp.array_equal(hard_result, hard_expected)
7273

73-
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
74+
symbolic_weights = utils.make_symbolic(hard_weights)
7475
symbolic_result = symbolic.apply(symbolic_weights, hard_input)
7576
symbolic_result = symbolic_generation.eval_symbolic_expression(
7677
symbolic_result)

tests/test_real_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from neurallogic import (harden, neural_logic_net, real_encoder,
1111
symbolic_generation)
12+
from tests import utils
1213

1314
# Uncomment to debug NaNs
1415
# config.update("jax_debug_nans", True)
@@ -249,7 +250,7 @@ def test_net(type, x):
249250

250251
# Compute symbolic result with symbolic inputs and symbolic weights, but where the symbols can be evaluated
251252
symbolic_input = ['1.0', '0.0']
252-
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
253+
symbolic_weights = utils.make_symbolic(hard_weights)
253254
symbolic_output = symbolic.apply(symbolic_weights, symbolic_input)
254255
symbolic_output = symbolic_generation.eval_symbolic_expression(
255256
symbolic_output)

tests/test_symbolic_generation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_not, symbolic_generation, symbolic_primitives
2-
from tests import test_mnist
3-
import numpy
41
import jax
52
import jax.numpy as jnp
3+
import numpy
4+
5+
from neurallogic import (hard_not, hard_or, harden, harden_layer,
6+
neural_logic_net, symbolic_generation,
7+
symbolic_primitives)
8+
from tests import test_mnist, utils
69

710

811
def nln(type, x, width):
@@ -45,7 +48,7 @@ def test_symbolic_generation():
4548
assert numpy.array_equal(symbolic_output, hard_output)
4649

4750
# Check the standard evaluation of the network equals the non-standard symbolic evaluation
48-
symbolic_mock_input = symbolic_generation.make_symbolic(hard_mock_input)
51+
symbolic_mock_input = utils.make_symbolic(hard_mock_input)
4952
symbolic_output = symbolic.apply(hard_weights, symbolic_mock_input)
5053
assert numpy.array_equal(hard_output.shape, symbolic_output.shape)
5154

0 commit comments

Comments
 (0)