Skip to content

Commit b382177

Browse files
committed
cleanup
1 parent 2388a8f commit b382177

File tree

3 files changed

+1
-87
lines changed

3 files changed

+1
-87
lines changed

neurallogic/harden.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,51 +8,6 @@
88
def harden_float(x: float) -> bool:
99
return x > 0.5
1010

11-
"""
12-
harden_array = jax.vmap(harden_float, 0, 0)
13-
14-
def harden_dict(x: dict) -> dict:
15-
return {k: harden(v) for k, v in x.items()}
16-
17-
@dispatch
18-
def harden(x: float):
19-
return harden_float(x)
20-
21-
@dispatch
22-
def harden(x: list):
23-
return [harden(xi) for xi in x]
24-
25-
@dispatch
26-
def harden(x: flax.core.frozen_dict.FrozenDict):
27-
return harden_dict(x)
28-
29-
@dispatch
30-
def harden(x: dict):
31-
return harden_dict(x)
32-
33-
@dispatch
34-
def harden(x: numpy.ndarray):
35-
if x.shape != ():
36-
return harden_array(x)
37-
else:
38-
return numpy.array(harden(x.item()))
39-
40-
@dispatch
41-
def harden(x: jax.numpy.ndarray):
42-
if x.shape != ():
43-
return harden_array(x)
44-
else:
45-
return numpy.array(harden(x.item()))
46-
47-
@dispatch
48-
def harden(*args):
49-
#print(f'harden: {args} of type {type(args)} with length {len(args)}')
50-
#print(f'type of elements are {[type(arg) for arg in args]}')
51-
#if len(args) == 1:
52-
# return harden(args[0])
53-
return tuple([harden(arg) for arg in args])
54-
"""
55-
5611
@dispatch
5712
def harden(x: float):
5813
return harden_float(x)

neurallogic/sym_gen.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -139,56 +139,33 @@ def eval_jaxpr_impl(jaxpr):
139139
@dispatch
140140
def make_symbolic(x: dict):
141141
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
142-
@dispatch
143-
def make_numeric(x: dict):
144-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_numeric_value)
145-
146142

147143
@dispatch
148144
def make_symbolic(x: list):
149145
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
150-
@dispatch
151-
def make_numeric(x: list):
152-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_numeric_value)
153-
154146

155147
@dispatch
156148
def make_symbolic(x: numpy.ndarray):
157149
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
158-
@dispatch
159-
def make_numeric(x: numpy.ndarray):
160-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_numeric_value)
161-
162150

163151
@dispatch
164152
def make_symbolic(x: jax.numpy.ndarray):
165153
return symbolic_primitives.map_at_elements(convert_jax_to_numpy_arrays(x), symbolic_primitives.to_boolean_value_string)
166-
@dispatch
167-
def make_numeric(x: jax.numpy.ndarray):
168-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_numeric_value)
169154

170155

171156
@dispatch
172157
def make_symbolic(x: bool):
173158
return symbolic_primitives.to_boolean_value_string(x)
174-
@dispatch
175-
def make_numeric(x: bool):
176-
return symbolic_primitives.to_numeric_value(x)
177159

178160

179161
@dispatch
180162
def make_symbolic(x: str):
181163
return symbolic_primitives.to_boolean_value_string(x)
182-
@dispatch
183-
def make_numeric(x: str):
184-
return symbolic_primitives.to_numeric_value(x)
164+
185165

186166
@dispatch
187167
def make_symbolic(*args):
188168
return tuple([make_symbolic(arg) for arg in args])
189-
@dispatch
190-
def make_numeric(*args):
191-
return tuple([make_numeric(arg) for arg in args])
192169

193170

194171
@dispatch
@@ -205,10 +182,6 @@ def convert_jax_to_numpy_arrays(x: dict):
205182
def make_symbolic(x: flax.core.FrozenDict):
206183
x = convert_jax_to_numpy_arrays(x.unfreeze())
207184
return flax.core.FrozenDict(make_symbolic(x))
208-
@dispatch
209-
def make_numeric(x: flax.core.FrozenDict):
210-
x = x.unfreeze()
211-
return flax.core.FrozenDict(make_numeric(x))
212185

213186

214187
@dispatch

neurallogic/symbolic_primitives.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,6 @@ def convert(x):
2222
return map_at_elements(x, convert)
2323

2424

25-
# TODO: remove me
26-
def convert_iterable_type(x: list, new_type):
27-
if new_type == list:
28-
return x
29-
elif new_type == numpy.ndarray:
30-
return numpy.array(x, dtype=object)
31-
elif new_type == jax.numpy.ndarray:
32-
return jax.numpy.array(x, dtype=object)
33-
elif new_type == jaxlib.xla_extension.DeviceArray:
34-
return jax.numpy.array(x, dtype=object)
35-
else:
36-
raise NotImplementedError(
37-
f"Cannot convert type {type(x)} to type {new_type}")
38-
3925
# TODO: allow func callable to control the type of the numpy.array or jax.numpy.array
4026

4127
# map_at_elements should alter the elements but not the type of the container

0 commit comments

Comments
 (0)