Commit 9a241fa

refactor dropout; noisy xor test wip
1 parent be31460 commit 9a241fa

2 files changed: +108 -86 lines changed

neurallogic/hard_dropout.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+from typing import Optional, Sequence
+
+import jax
+from flax import linen as nn
+from jax import lax, random
+
+from neurallogic import neural_logic_net
+
+
+class SoftHardDropout(nn.Module):
+    """Create a dropout layer suitable for dropping soft-bit values.
+    Adapted from flax/stochastic.py
+
+
+    Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+    to include an RNG seed named `'dropout'`. For example::
+
+        model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
+
+    Attributes:
+        rate: the dropout probability. (_not_ the keep rate!)
+        broadcast_dims: dimensions that will share the same dropout mask
+        deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+            masked, whereas if true, no mask is applied and the inputs are returned
+            as is.
+        rng_collection: the rng collection name to use when requesting an rng key.
+    """
+
+    rate: float
+    broadcast_dims: Sequence[int] = ()
+    deterministic: Optional[bool] = None
+    rng_collection: str = "dropout"
+    dropout_value: float = 0.0
+
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        """Applies a random dropout mask to the input.
+
+        Args:
+            inputs: the inputs that should be randomly masked.
+                Masking means setting the input bits to 0.5.
+            deterministic: if false the inputs are masked,
+                whereas if true, no mask is applied and the inputs are returned
+                as is.
+
+        Returns:
+            The masked inputs
+        """
+        deterministic = nn.merge_param(
+            "deterministic", self.deterministic, deterministic
+        )
+
+        if (self.rate == 0.0) or deterministic:
+            return inputs
+
+        # Prevent gradient NaNs in 1.0 edge-case.
+        if self.rate == 1.0:
+            return jax.numpy.zeros_like(inputs)
+
+        keep_prob = 1.0 - self.rate
+        rng = self.make_rng(self.rng_collection)
+        broadcast_shape = list(inputs.shape)
+        for dim in self.broadcast_dims:
+            broadcast_shape[dim] = 1
+        mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+        mask = jax.numpy.broadcast_to(mask, inputs.shape)
+        masked_values = jax.numpy.full_like(inputs, self.dropout_value, dtype=float)
+        return lax.select(mask, inputs, masked_values)
+
+
+class HardHardDropout(nn.Module):
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        return inputs
+
+
+class SymbolicHardDropout(nn.Module):
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        return inputs
+
+
+hard_dropout = neural_logic_net.select(
+    lambda **kwargs: SoftHardDropout(**kwargs),
+    lambda **kwargs: HardHardDropout(**kwargs),
+    lambda **kwargs: SymbolicHardDropout(**kwargs),
+)
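Usage note (not part of the commit): a minimal sketch of exercising the new SoftHardDropout layer on its own. The input values, rate, and dropout_value below are illustrative assumptions. The layer has no parameters, so an empty variables dict suffices, but an RNG named 'dropout' must be supplied whenever masking is active, as the docstring notes.

import jax
import jax.numpy as jnp

from neurallogic.hard_dropout import SoftHardDropout

soft_bits = jnp.array([0.1, 0.9, 0.8, 0.2])           # soft-bit inputs in [0, 1]
layer = SoftHardDropout(rate=0.5, dropout_value=0.5)  # dropped bits become 0.5 ("unknown")

# Training mode: supply the 'dropout' RNG so make_rng can draw the Bernoulli keep-mask.
out = layer.apply(
    {}, soft_bits, deterministic=False, rngs={"dropout": jax.random.PRNGKey(0)}
)

# Evaluation mode: no RNG needed; inputs pass through unchanged.
assert (layer.apply({}, soft_bits, deterministic=True) == soft_bits).all()

The HardHardDropout and SymbolicHardDropout variants are identity maps, so only the soft (training) path is perturbed.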

tests/test_noisy_xor.py

Lines changed: 21 additions & 86 deletions
@@ -1,14 +1,10 @@
 from pathlib import Path
-from typing import Optional, Sequence
 
 import jax
-import jax.numpy as jnp
 import ml_collections
 import numpy
 import optax
-from flax import linen as nn
 from flax.training import train_state
-from jax import lax, random
 from tqdm import tqdm
 
 from neurallogic import (
@@ -17,6 +13,7 @@
     hard_not,
     hard_or,
     hard_xor,
+    hard_dropout,
     harden,
     harden_layer,
     neural_logic_net,
@@ -60,71 +57,6 @@ def check_symbolic(nets, data, trained_state):
     print("symbolic_output", symbolic_output[0][:10000])
 
 
-class BinaryDropout(nn.Module):
-    """Create a dropout layer.
-
-    Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
-    to include an RNG seed named `'dropout'`. For example::
-
-        model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
-
-    Attributes:
-        rate: the dropout probability. (_not_ the keep rate!)
-        broadcast_dims: dimensions that will share the same dropout mask
-        deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
-            masked, whereas if true, no mask is applied and the inputs are returned
-            as is.
-        rng_collection: the rng collection name to use when requesting an rng key.
-    """
-
-    rate: float
-    broadcast_dims: Sequence[int] = ()
-    deterministic: Optional[bool] = None
-    rng_collection: str = "dropout"
-
-    @nn.compact
-    def __call__(self, inputs, deterministic: Optional[bool] = None):
-        """Applies a random dropout mask to the input.
-
-        Args:
-            inputs: the inputs that should be randomly masked.
-                Masking means setting the input bits to 0.5.
-            deterministic: if false the inputs are masked,
-                whereas if true, no mask is applied and the inputs are returned
-                as is.
-
-        Returns:
-            The masked inputs
-        """
-        deterministic = nn.merge_param(
-            "deterministic", self.deterministic, deterministic
-        )
-
-        if (self.rate == 0.0) or deterministic:
-            return inputs
-
-        # Prevent gradient NaNs in 1.0 edge-case.
-        if self.rate == 1.0:
-            return jnp.zeros_like(inputs)
-
-        keep_prob = 1.0 - self.rate
-        rng = self.make_rng(self.rng_collection)
-        broadcast_shape = list(inputs.shape)
-        for dim in self.broadcast_dims:
-            broadcast_shape[dim] = 1
-        mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
-        mask = jnp.broadcast_to(mask, inputs.shape)
-        # masked_values = jnp.ones_like(inputs, dtype=float) / 2.0
-        masked_values = jnp.zeros_like(inputs, dtype=float)
-        # masked_values = jnp.ones_like(inputs, dtype=float)
-        # print(f"mask {mask.shape} {mask.dtype} {mask}")
-        # print(
-        #     f"masked_values {masked_values.shape} {masked_values.dtype} {masked_values}"
-        # )
-        # print(f"inputs {inputs.shape} {inputs.dtype} {inputs}")
-        return lax.select(mask, inputs, masked_values)
-
-
 num_features = 12
 num_classes = 2
 
@@ -144,7 +76,7 @@ def get_data():
 
 
 # 100% test accuracy
-def nln_100(type, x):
+def nln(type, x, training: bool):
     x = hard_and.and_layer(type)(20)(x)
     x = hard_not.not_layer(type)(4)(x)
     x = x.ravel()
@@ -155,12 +87,15 @@ def nln_100(type, x):
     return x
 
 
-# 100% test accuracy
-def nln(type, x, training: bool):
-    x = hard_xor.xor_layer(type)(40)(x)
-    # x = hard_not.not_layer(type)(1)(x)
-    x = BinaryDropout(rate=0.5, deterministic=not training)(x)
-    # x = x.ravel()
+def nln_experimental(type, x, training: bool):
+    not_x = jax.numpy.logical_not(x)
+    input = jax.numpy.concatenate([x, not_x], axis=0)
+    x = hard_xor.xor_layer(type)(100)(input)
+    # x = hard_not.not_layer(type)(4)(x)
+    x = x.ravel()
+    x = hard_dropout.hard_dropout(type)(
+        rate=0.5, dropout_value=0.0, deterministic=not training
+    )(x)
     ########################################################
     x = harden_layer.harden_layer(type)(x)
     x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
@@ -280,18 +215,18 @@ def train_and_evaluate(
         )
         if train_accuracy > best_train_accuracy:
             best_train_accuracy = train_accuracy
-            print(f"best_train_accuracy: {best_train_accuracy * 100:.2f}")
+            # print(f"best_train_accuracy: {best_train_accuracy * 100:.2f}")
         if test_accuracy >= best_test_accuracy:
             best_test_accuracy = test_accuracy
-            print(f"best_test_accuracy: {best_test_accuracy * 100:.2f}")
-        else:
-            print(f"test_accuracy: {test_accuracy * 100:.2f}")
-        print("\n")
+            # print(f"best_test_accuracy: {best_test_accuracy * 100:.2f}")
+        # else:
+        #     print(f"test_accuracy: {test_accuracy * 100:.2f}")
+        # print("\n")
 
-        print(
-            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
-            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
-        )
+        # print(
+        #     "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
+        #     % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
+        # )
 
     return state
 
@@ -320,7 +255,7 @@ def get_config():
     config.learning_rate = 0.01
     config.momentum = 0.9
     config.batch_size = 256
-    config.num_epochs = 1000
+    config.num_epochs = 500
     return config
 
 
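For reference, a rough sketch (not from this commit) of how a training setup would thread the 'dropout' RNG through init and apply once a net uses hard_dropout with deterministic=not training. The TinyNet model, shapes, and key handling below are stand-ins and assumptions, not the noisy-XOR network in this test.

import jax
import jax.numpy as jnp
from flax import linen as nn

from neurallogic.hard_dropout import SoftHardDropout


class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(4)(x)
        # Same pattern as nln_experimental: dropout is only active when training.
        x = SoftHardDropout(
            rate=0.5, dropout_value=0.0, deterministic=not training
        )(x)
        return x.sum()


model = TinyNet()
x = jnp.ones((12,))
rng, init_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0), 3)

variables = model.init({"params": init_rng, "dropout": dropout_rng}, x, training=True)

# Each training step should draw a fresh dropout key so masks differ across steps.
rng, dropout_rng = jax.random.split(rng)
train_out = model.apply(variables, x, training=True, rngs={"dropout": dropout_rng})

# Evaluation: no dropout RNG required.
eval_out = model.apply(variables, x, training=False)

This mirrors the model.apply(..., rngs={'dropout': dropout_rng}) note in the hard_dropout.py docstring and the deterministic=not training call in nln_experimental above.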