Skip to content

Commit 63a167b

Browse files
committed
experiment with binary dropout
1 parent c81c44e commit 63a167b

File tree

3 files changed

+244
-55
lines changed

3 files changed

+244
-55
lines changed

neurallogic/neural_logic_net.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
from enum import Enum
22
from flax import linen as nn
33

4-
NetType = Enum('NetType', ['Soft', 'Hard', 'Symbolic'])
4+
NetType = Enum("NetType", ["Soft", "Hard", "Symbolic"])
5+
56

67
def select(soft, hard, symbolic):
78
def selector(type: NetType):
8-
return {
9-
NetType.Soft: soft,
10-
NetType.Hard: hard,
11-
NetType.Symbolic: symbolic
12-
}[type]
9+
return {NetType.Soft: soft, NetType.Hard: hard, NetType.Symbolic: symbolic}[
10+
type
11+
]
12+
1313
return selector
1414

15+
1516
def net(f):
1617
class SoftNet(nn.Module):
1718
@nn.compact
18-
def __call__(self, x):
19-
return f(NetType.Soft, x)
20-
class HardNet(nn.Module):
19+
def __call__(self, x, **kwargs):
20+
return f(NetType.Soft, x, **kwargs)
21+
22+
class HardNet(nn.Module):
2123
@nn.compact
2224
def __call__(self, x):
2325
return f(NetType.Hard, x)
26+
2427
class SymbolicNet(nn.Module):
2528
@nn.compact
2629
def __call__(self, x):
2730
return f(NetType.Symbolic, x)
28-
return SoftNet(), HardNet(), SymbolicNet()
2931

32+
return SoftNet(), HardNet(), SymbolicNet()

tests/test_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,5 +325,5 @@ def test_mnist():
325325
)
326326

327327
# Check symbolic net
328-
# _, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
329-
# check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
328+
_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
329+
check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)

0 commit comments

Comments
 (0)