Skip to content

Commit 4c3173a

Browse files
committed
maintain test_harden
1 parent a786e05 commit 4c3173a

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

neurallogic/harden.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ def harden_float(x: float) -> bool:
1212
def harden(x: float):
1313
return harden_float(x)
1414

15+
@dispatch
16+
def harden(x: list):
17+
return symbolic_primitives.map_at_elements(x, harden_float)
18+
1519
@dispatch
1620
def harden(x: numpy.ndarray):
1721
return symbolic_primitives.map_at_elements(x, harden_float)
@@ -24,6 +28,10 @@ def harden(x: jax.numpy.ndarray):
2428
def harden(x: dict):
2529
return symbolic_primitives.map_at_elements(x, harden_float)
2630

31+
@dispatch
32+
def harden(x: flax.core.FrozenDict):
33+
return flax.core.FrozenDict(symbolic_primitives.map_at_elements(x.unfreeze(), harden_float))
34+
2735
@dispatch
2836
def harden(*args):
2937
if len(args) == 1:
@@ -34,9 +42,5 @@ def harden(*args):
3442
def map_keys_nested(f, d: dict) -> dict:
3543
return {f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items()}
3644

37-
3845
def hard_weights(weights):
39-
unfrozen_weights = weights.unfreeze()
40-
hard_weights = harden(unfrozen_weights)
41-
return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Hard"), hard_weights))
42-
46+
return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze())))

tests/test_harden.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@ def test_harden_float():
1212

1313

1414
def test_harden_list():
15-
assert harden.harden_list([0.5, 0.6, 0.4, 0.0, 1.0]) == [
15+
assert harden.harden([0.5, 0.6, 0.4, 0.0, 1.0]) == [
1616
False, True, False, False, True]
1717

1818

1919
def test_harden_array():
20-
assert jnp.array_equal(harden.harden_array(
20+
assert jnp.array_equal(harden.harden(
2121
jnp.array([0.5, 0.6, 0.4, 0.0, 1.0])), [False, True, False, False, True])
2222

2323

2424
def test_harden_dict():
2525
dict = {'a': 0.5, 'b': 0.6, 'c': 0.4, 'd': 0.0, 'e': 1.0}
2626
expected_dict = {'a': False, 'b': True, 'c': False, 'd': False, 'e': True}
27-
assert harden.harden_dict(dict) == expected_dict
27+
assert harden.harden(dict) == expected_dict
2828

2929

3030
def test_harden_frozen_dict():
3131
dict = flax.core.frozen_dict.FrozenDict(
3232
{'a': 0.5, 'b': 0.6, 'c': 0.4, 'd': 0.0, 'e': 1.0})
3333
expected_dict = {'a': False, 'b': True, 'c': False, 'd': False, 'e': True}
34-
assert harden.harden_dict(dict) == expected_dict
34+
assert harden.harden(dict) == expected_dict
3535

3636

3737
def test_harden():

0 commit comments

Comments
 (0)