
Commit f4ee2f8

cleanup

1 parent 7d5f635

6 files changed (+5, -21 lines)


neurallogic/hard_majority.py

Lines changed: 3 additions & 6 deletions

@@ -32,11 +32,8 @@ class SoftMajorityLayer(nn.Module):
         layer_size: The number of neurons in the layer.
         weights_init: The initializer function for the weight matrix.
     """
-    dtype: jax.numpy.dtype = jax.numpy.float32
-
     @nn.compact
     def __call__(self, x):
-        x = jax.numpy.asarray(x, self.dtype)  # TODO: remove me?
         return soft_majority_layer(x)


@@ -57,7 +54,7 @@ def __call__(self, x):


 majority_layer = neural_logic_net.select(
-    lambda dtype=jax.numpy.float32: SoftMajorityLayer(dtype),
-    lambda dtype=jax.numpy.float32: HardMajorityLayer(),
-    lambda dtype=jax.numpy.float32: SymbolicMajorityLayer()
+    lambda: SoftMajorityLayer(),
+    lambda: HardMajorityLayer(),
+    lambda: SymbolicMajorityLayer()
 )
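With the dtype field and its cast removed, SoftMajorityLayer takes no constructor arguments, which is why the three select lambdas collapse to zero-argument form. The diff does not show soft_majority_layer itself; the following is a hypothetical sketch of one common soft/hard majority pairing, not the repository's implementation:

import jax.numpy as jnp

def soft_majority_sketch(x):
    # Soft bits live in [0, 1]; their mean is a differentiable majority
    # surrogate that exceeds 0.5 when most bits lean towards 1.
    return jnp.mean(x, axis=-1)

def hard_majority_sketch(x):
    # Hard bits are booleans; a strict majority is more than half True.
    return jnp.sum(x, axis=-1) > (x.shape[-1] // 2)

print(soft_majority_sketch(jnp.array([0.9, 0.8, 0.2])))          # ~0.633
print(hard_majority_sketch(jnp.array([True, True, False])))      # True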

neurallogic/hard_not.py

Lines changed: 0 additions & 1 deletion

@@ -19,7 +19,6 @@ def soft_not(w: float, x: float) -> float:


 def hard_not(w: bool, x: bool) -> bool:
-    # ~(x ^ w)
     return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))
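The deleted comment merely restated the body: hard_not computes NOT(x XOR w), so a True weight passes the input through and a False weight negates it. A quick truth-table check, using only the calls shown in the diff:

import jax.numpy as jnp

def hard_not(w, x):
    return jnp.logical_not(jnp.logical_xor(x, w))

for w in (False, True):
    for x in (False, True):
        print(f'w={w}, x={x} -> {bool(hard_not(w, x))}')
# w=False yields NOT x; w=True yields x unchanged.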

neurallogic/hard_or.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
             'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)  # TODO: remove me?
+        x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)
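For context on the surviving lines: in Flax linen, self.param registers a parameter and invokes the initializer with a PRNG key followed by the extra arguments (here a shape and a dtype). A minimal standalone module using the same pattern; the class name, initializer choice, and matmul body are illustrative, not the repository's:

import jax
import flax.linen as nn

class TinyLayer(nn.Module):
    layer_size: int
    dtype: jax.numpy.dtype = jax.numpy.float32

    @nn.compact
    def __call__(self, x):
        # self.param(name, init_fn, *args): Flax calls init_fn(rng_key, *args)
        # the first time the module is initialised.
        shape = (self.layer_size, jax.numpy.shape(x)[-1])
        w = self.param('bit_weights', nn.initializers.uniform(), shape, self.dtype)
        x = jax.numpy.asarray(x, self.dtype)  # the cast this diff keeps
        return x @ w.T

layer = TinyLayer(layer_size=4)
params = layer.init(jax.random.PRNGKey(0), jax.numpy.ones((2, 3)))
y = layer.apply(params, jax.numpy.ones((2, 3)))  # shape (2, 4)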

neurallogic/hard_xor.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
             'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)  # TODO is this needed?
+        x = jax.numpy.asarray(x, self.dtype)
         return soft_xor_layer(weights, x)

neurallogic/harden.py

Lines changed: 0 additions & 11 deletions

@@ -62,17 +62,6 @@ def harden(x: flax.core.FrozenDict):
     return harden(x.unfreeze())


-"""
-@dispatch
-def harden(*args):
-    if len(args) == 1:
-        print(f'args = {args} of type {type(args)}')
-        arg = args[0]
-        print(f'args[0] = {arg}')
-        return tuple(harden(arg))
-    return tuple([harden(arg) for arg in args])
-"""
-
 @dispatch
 def map_keys_nested(f, d: dict) -> dict:
     return {
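The commented-out variadic overload was dead code; the live FrozenDict overload simply unfreezes and recurses. A hypothetical sketch of that recursion, using plain isinstance checks instead of the repository's @dispatch decorator and assuming a 0.5 hardening threshold (neither of which this diff shows):

def harden_sketch(x):
    # Containers recurse; leaf soft bits threshold to booleans.
    if isinstance(x, dict):
        return {k: harden_sketch(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(harden_sketch(v) for v in x)
    return x > 0.5  # assumed rule for a soft bit in [0, 1]

print(harden_sketch({'bit_weights': [0.9, 0.1, 0.6]}))
# {'bit_weights': [True, False, True]}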

neurallogic/neural_logic_net.py

Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 from enum import Enum
 from flax import linen as nn
-from neurallogic import symbolic_generation

 NetType = Enum('NetType', ['Soft', 'Hard', 'Symbolic'])
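NetType enumerates the three network flavours that the zero-argument lambdas in hard_majority.py are selected between. select's body is not part of this diff; a plausible sketch consistent with its call sites, offered only as an assumption:

from enum import Enum

NetType = Enum('NetType', ['Soft', 'Hard', 'Symbolic'])

def select(soft, hard, symbolic):
    # Hypothetical: return a chooser mapping a NetType onto the matching
    # zero-argument constructor passed in by each layer module.
    def choose(net_type):
        return {NetType.Soft: soft,
                NetType.Hard: hard,
                NetType.Symbolic: symbolic}[net_type]()
    return choose

majority_for = select(lambda: 'SoftMajorityLayer()',
                      lambda: 'HardMajorityLayer()',
                      lambda: 'SymbolicMajorityLayer()')
print(majority_for(NetType.Hard))  # 'HardMajorityLayer()'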
