gssl.classifiers.nn package
Submodules
gssl.classifiers.nn.NN module
A neural network approach that I previously used in some experiments. Ignore this module for now.
class gssl.classifiers.nn.NN.Accumulator(tensor, name, init_val=0.0)
Bases: object
    __init__(tensor, name, init_val=0.0)
        Initialize self. See help(type(self)) for accurate signature.
class gssl.classifiers.nn.NN.NNClassifier(image_shape=None, OUT_SIZE=10, NUM_EPOCHS=2000, BATCH_SIZE=50, AUGMENT=False, model_choice='simple', LEARNING_RATE=0.0001)
Bases: gssl.classifiers.classifier.GSSLClassifier
A NN classifier that was intended to optimize a labeled and an unlabeled objective at the same time. Please ignore it for the time being.
    ALPHA = 0.1
    LAMBDA = 0.5
    RECALC_W = False
    SIGMA = 0.04 (a scalar float32 tf.Tensor)
    USE_UNLABELED = True
    __init__(image_shape=None, OUT_SIZE=10, NUM_EPOCHS=2000, BATCH_SIZE=50, AUGMENT=False, model_choice='simple', LEARNING_RATE=0.0001)
        Constructor for the NN classifier.
    build_graph(X, k)
    eval_get_data = None
    evaluate_simfunc(W_sparse_vals)
    fit(X, W, Y, labeledIndexes, hook=None, Y_true=None)
        Classifies the input data.
        Parameters
            X (NDArray[float].shape[N,D]) – Input matrix of N instances of dimension D.
            W (NDArray[float].shape[N,N]) – The affinity matrix encoding the weighted edges.
            Y (NDArray[float].shape[N,C]) – The initial belief matrix.
            hook (GSSLHook) – Optional. A hook that executes extra operations (e.g. plots) during the algorithm.
        Returns
            An updated belief matrix.
        Return type
            NDArray[float].shape[N,C]
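The parameter list above fixes the shapes of the main inputs. As a minimal sketch of inputs with those shapes, the NumPy code below builds a toy X, a dense affinity matrix W from a Gaussian kernel, and a one-hot belief matrix Y that is zero on unlabeled rows. The Gaussian kernel, the boolean-mask form of labeledIndexes (which is not documented above) and the helper name build_toy_inputs are all assumptions made for illustration.

```python
import numpy as np

def build_toy_inputs(n=6, d=2, c=3, n_labeled=3, sigma=1.0, seed=0):
    """Hypothetical helper: toy (X, W, Y, labeledIndexes) with the shapes fit() documents."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))                       # N x D input matrix
    # Assumption: W is a dense Gaussian (RBF) affinity matrix with a zero diagonal.
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    W = np.exp(-sq_dists / (2.0 * sigma ** 2))        # N x N affinity matrix
    np.fill_diagonal(W, 0.0)
    # Assumption: labeledIndexes is a boolean mask marking the labeled instances.
    labeledIndexes = np.zeros(n, dtype=bool)
    labeledIndexes[:n_labeled] = True
    labels = rng.integers(0, c, size=n)
    Y = np.zeros((n, c))                              # N x C initial belief matrix
    Y[labeledIndexes, labels[labeledIndexes]] = 1.0   # one-hot rows for labeled points only
    return X, W, Y, labeledIndexes

X, W, Y, labeledIndexes = build_toy_inputs()
```

Whether fit expects W as a dense array or a sparse matrix is not stated here; the sketch only demonstrates the documented shapes.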
    labeled_gen()
    pred_gen()
    random_gen()
    unlabeled_gen()
    unlabeled_pairs_gen()
gssl.classifiers.nn.NN.convert_sparse_matrix_to_sparse_tensor(X, var_values=False)
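convert_sparse_matrix_to_sparse_tensor has no docstring. Its name suggests it converts a SciPy sparse matrix into a TensorFlow SparseTensor, which is built from a COO triplet: int64 indices of shape [nnz, 2], the nonzero values, and the dense shape (many TensorFlow sparse ops expect the indices in row-major order). The sketch below computes that triplet with SciPy. The helper name is hypothetical, the meaning of var_values is unknown and not modeled, and the TensorFlow call is left as a comment so the code runs without TensorFlow.

```python
import numpy as np
import scipy.sparse as sp

def sparse_to_coo_triplet(X):
    """Hypothetical helper: (indices, values, dense_shape) for a SciPy sparse matrix."""
    coo = sp.coo_matrix(X)
    coo.sum_duplicates()                                  # add up repeated (row, col) entries
    indices = np.stack([coo.row, coo.col], axis=1).astype(np.int64)  # shape [nnz, 2]
    order = np.lexsort((indices[:, 1], indices[:, 0]))    # sort by row, then by column
    indices = indices[order]
    values = coo.data[order].astype(np.float32)
    dense_shape = np.array(coo.shape, dtype=np.int64)
    # With TensorFlow installed: tf.sparse.SparseTensor(indices, values, dense_shape)
    return indices, values, dense_shape

W = sp.csr_matrix(np.array([[0.0, 2.0, 0.0], [3.0, 0.0, 4.0]]))
indices, values, dense_shape = sparse_to_coo_triplet(W)
```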
gssl.classifiers.nn.NN.cos_decay(init_val, EPOCH_VAR, rampdown_length)
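cos_decay is undocumented. Its signature suggests a cosine ramp-down of init_val over rampdown_length epochs, a schedule commonly used for learning rates and unsupervised loss weights in semi-supervised training. A minimal sketch under that assumption is below; it uses plain Python numbers instead of a TensorFlow epoch variable, clamps the value to 0 once the ramp-down ends, and the name cos_decay_sketch is hypothetical.

```python
import math

def cos_decay_sketch(init_val, epoch, rampdown_length):
    """Assumed schedule: decay init_val to 0 along half a cosine period."""
    # Progress through the ramp-down, clamped to [0, 1].
    t = min(max(epoch, 0), rampdown_length) / rampdown_length
    return init_val * 0.5 * (1.0 + math.cos(math.pi * t))
```

At epoch 0 this returns init_val, halfway through it returns init_val / 2, and from rampdown_length onward it returns 0.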
gssl.classifiers.nn.NN.debug(msg)
gssl.classifiers.nn.NN.ent(Y)
gssl.classifiers.nn.NN.gather(x, F)
gssl.classifiers.nn.NN.get_S(W)
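get_S(W) is also undocumented. In graph-based SSL methods such as Local and Global Consistency, S commonly denotes the symmetrically normalized affinity matrix S = D^(-1/2) W D^(-1/2), where D is diagonal with D_ii = sum_j W_ij. Assuming that is what get_S computes (a guess from the name, not confirmed by the source), a NumPy sketch with the hypothetical name get_S_sketch is:

```python
import numpy as np

def get_S_sketch(W):
    """Assumed: S = D^{-1/2} W D^{-1/2}, with D the diagonal degree matrix of W."""
    W = np.asarray(W, dtype=float)
    d = W.sum(axis=1)                                         # node degrees
    with np.errstate(divide="ignore"):
        d_inv_sqrt = np.where(d > 0, 1.0 / np.sqrt(d), 0.0)   # isolated nodes map to 0
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}.
    return d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]
```

Mapping zero-degree nodes to zero rows is one possible convention; the real function may handle them differently.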
gssl.classifiers.nn.NN.get_S_fromtensor(W)
gssl.classifiers.nn.NN.kl_divergence(self, p, q)
gssl.classifiers.nn.NN.repeat(x, n)
gssl.classifiers.nn.NN.row_normalize(x)
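row_normalize(x) presumably divides each row by its sum, turning nonnegative scores (for example the rows of a belief matrix) into per-row distributions. A NumPy sketch under that assumption follows; the real function likely operates on TensorFlow tensors, and both the name row_normalize_sketch and the choice to leave all-zero rows at zero are assumptions.

```python
import numpy as np

def row_normalize_sketch(x):
    """Assumed: divide each row by its sum; all-zero rows stay zero."""
    x = np.asarray(x, dtype=float)
    sums = x.sum(axis=1, keepdims=True)
    safe = np.where(sums == 0, 1.0, sums)  # avoid division by zero on empty rows
    return x / safe
```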
gssl.classifiers.nn.NN.xent(distr_1, distr_2)
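ent, xent and kl_divergence are listed without docstrings. Assuming they compute, row by row, the entropy H(p), the cross-entropy H(p, q) and the KL divergence KL(p || q) of discrete distributions, the three are linked by KL(p || q) = H(p, q) - H(p). The NumPy sketch below illustrates that identity with the natural log and a small EPS to keep log(0) finite. The helper names, EPS and the argument order (p as the reference distribution) are assumptions; the sketch does not reproduce the module's TensorFlow code.

```python
import numpy as np

EPS = 1e-12  # assumed smoothing constant to keep log() finite

def entropy(p):
    """Row-wise H(p) = -sum_j p_j log p_j."""
    p = np.asarray(p, dtype=float)
    return -(p * np.log(p + EPS)).sum(axis=1)

def cross_entropy(p, q):
    """Row-wise H(p, q) = -sum_j p_j log q_j, with p the reference distribution."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return -(p * np.log(q + EPS)).sum(axis=1)

def kl(p, q):
    """Row-wise KL(p || q) = H(p, q) - H(p), which is >= 0."""
    return cross_entropy(p, q) - entropy(p)

p = np.array([[0.5, 0.5], [1.0, 0.0]])
q = np.array([[0.9, 0.1], [1.0, 0.0]])
```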
gssl.classifiers.nn.models module
Created on 4 Oct 2019
@author: klaus
gssl.classifiers.nn.models.conv_large(input_shape, output_shape)
gssl.classifiers.nn.models.conv_small(input_shape, output_shape)
gssl.classifiers.nn.models.linear(input_shape, output_shape)
gssl.classifiers.nn.models.simple(input_shape, output_shape)