from pathlib import Path
- from typing import Optional, Sequence

import jax
- import jax.numpy as jnp
import ml_collections
import numpy
import optax
- from flax import linen as nn
from flax.training import train_state
- from jax import lax, random
from tqdm import tqdm

from neurallogic import (
    hard_not,
    hard_or,
    hard_xor,
+     hard_dropout,
    harden,
    harden_layer,
    neural_logic_net,
@@ -60,71 +57,6 @@ def check_symbolic(nets, data, trained_state):
    print("symbolic_output", symbolic_output[0][:10000])


- class BinaryDropout(nn.Module):
-     """Create a dropout layer.
-
-     Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
-     to include an RNG seed named `'dropout'`. For example::
-
-       model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
-
-     Attributes:
-       rate: the dropout probability. (_not_ the keep rate!)
-       broadcast_dims: dimensions that will share the same dropout mask
-       deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
-         masked, whereas if true, no mask is applied and the inputs are returned
-         as is.
-       rng_collection: the rng collection name to use when requesting an rng key.
-     """
-
-     rate: float
-     broadcast_dims: Sequence[int] = ()
-     deterministic: Optional[bool] = None
-     rng_collection: str = "dropout"
-
-     @nn.compact
-     def __call__(self, inputs, deterministic: Optional[bool] = None):
-         """Applies a random dropout mask to the input.
-
-         Args:
-           inputs: the inputs that should be randomly masked.
-             Masking means setting the input bits to 0.5.
-           deterministic: if false the inputs are masked,
-             whereas if true, no mask is applied and the inputs are returned
-             as is.
-
-         Returns:
-           The masked inputs
-         """
-         deterministic = nn.merge_param(
-             "deterministic", self.deterministic, deterministic
-         )
-
-         if (self.rate == 0.0) or deterministic:
-             return inputs
-
-         # Prevent gradient NaNs in 1.0 edge-case.
-         if self.rate == 1.0:
-             return jnp.zeros_like(inputs)
-
-         keep_prob = 1.0 - self.rate
-         rng = self.make_rng(self.rng_collection)
-         broadcast_shape = list(inputs.shape)
-         for dim in self.broadcast_dims:
-             broadcast_shape[dim] = 1
-         mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
-         mask = jnp.broadcast_to(mask, inputs.shape)
-         # masked_values = jnp.ones_like(inputs, dtype=float) / 2.0
-         masked_values = jnp.zeros_like(inputs, dtype=float)
-         # masked_values = jnp.ones_like(inputs, dtype=float)
-         # print(f"mask {mask.shape} {mask.dtype} {mask}")
-         # print(
-         #     f"masked_values {masked_values.shape} {masked_values.dtype} {masked_values}"
-         # )
-         # print(f"inputs {inputs.shape} {inputs.dtype} {inputs}")
-         return lax.select(mask, inputs, masked_values)
-
-
num_features = 12
num_classes = 2
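The removed BinaryDropout masks a random fraction of the soft-bit inputs to a fixed value instead of rescaling the survivors by 1/(1 - rate); the hard_dropout layer introduced in this commit appears to expose the same idea through its dropout_value argument. A minimal standalone sketch of that masking pattern in plain JAX (fixed_value_dropout is a hypothetical helper for illustration, not the neurallogic API):

import jax
import jax.numpy as jnp

def fixed_value_dropout(rng, x, rate=0.5, dropout_value=0.0):
    # Replace roughly `rate` of the entries with `dropout_value`.
    # Unlike standard dropout, kept entries are not rescaled,
    # so soft bits stay inside [0, 1].
    keep = jax.random.bernoulli(rng, p=1.0 - rate, shape=x.shape)
    return jnp.where(keep, x, dropout_value)

x = jnp.array([0.9, 0.1, 0.8, 0.2])
print(fixed_value_dropout(jax.random.PRNGKey(0), x))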
@@ -144,7 +76,7 @@ def get_data():


# 100% test accuracy
- def nln_100(type, x):
+ def nln(type, x, training: bool):
    x = hard_and.and_layer(type)(20)(x)
    x = hard_not.not_layer(type)(4)(x)
    x = x.ravel()
@@ -155,12 +87,15 @@ def nln_100(type, x):
    return x


- # 100% test accuracy
- def nln(type, x, training: bool):
-     x = hard_xor.xor_layer(type)(40)(x)
-     # x = hard_not.not_layer(type)(1)(x)
-     x = BinaryDropout(rate=0.5, deterministic=not training)(x)
-     # x = x.ravel()
+ def nln_experimental(type, x, training: bool):
+     not_x = jax.numpy.logical_not(x)
+     input = jax.numpy.concatenate([x, not_x], axis=0)
+     x = hard_xor.xor_layer(type)(100)(input)
+     # x = hard_not.not_layer(type)(4)(x)
+     x = x.ravel()
+     x = hard_dropout.hard_dropout(type)(
+         rate=0.5, dropout_value=0.0, deterministic=not training
+     )(x)
    ########################################################
    x = harden_layer.harden_layer(type)(x)
    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
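The new nln_experimental concatenates the input with its element-wise negation before the XOR layer, presumably so downstream gates can see both polarities of every input bit without a preceding NOT layer. A small illustration of that input doubling in plain JAX (values are arbitrary):

import jax.numpy as jnp

x = jnp.array([True, False, False])
not_x = jnp.logical_not(x)           # negated copy of the input bits
doubled = jnp.concatenate([x, not_x], axis=0)
print(doubled)  # [ True False False False  True  True]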
@@ -280,18 +215,18 @@ def train_and_evaluate(
        )
        if train_accuracy > best_train_accuracy:
            best_train_accuracy = train_accuracy
-             print(f"best_train_accuracy: {best_train_accuracy * 100:.2f}")
+             # print(f"best_train_accuracy: {best_train_accuracy * 100:.2f}")
        if test_accuracy >= best_test_accuracy:
            best_test_accuracy = test_accuracy
-             print(f"best_test_accuracy: {best_test_accuracy * 100:.2f}")
-         else:
-             print(f"test_accuracy: {test_accuracy * 100:.2f}")
-             print("\n")
+             # print(f"best_test_accuracy: {best_test_accuracy * 100:.2f}")
+         # else:
+         #     print(f"test_accuracy: {test_accuracy * 100:.2f}")
+         #     print("\n")

-         print(
-             "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
-             % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
-         )
+         # print(
+         #     "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
+         #     % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
+         # )

    return state
@@ -320,7 +255,7 @@ def get_config():
    config.learning_rate = 0.01
    config.momentum = 0.9
    config.batch_size = 256
-     config.num_epochs = 1000
+     config.num_epochs = 500
    return config