diff --git a/Naive Bayes Classifier/NaiveBayes.swift b/Naive Bayes Classifier/NaiveBayes.swift index 46a0bb4f5..20bf86c24 100644 --- a/Naive Bayes Classifier/NaiveBayes.swift +++ b/Naive Bayes Classifier/NaiveBayes.swift @@ -39,7 +39,7 @@ enum NBType { case gaussian case multinomial - //case bernoulli --> TODO + case bernoulli func calcLikelihood(variables: [Any], input: Any) -> Double? { @@ -76,6 +76,22 @@ enum NBType { return variable.category == input }?.probability + } else if case .bernoulli = self { + + guard let variables = variables as? [(category: Int, probability: Double)] else { + return nil + } + + guard let input = input as? Bool else { + return nil + } + + let probability = variables.first { variable in + return variable.category == (input ? 1 : 0) + }?.probability ?? 0.0 + + return probability + } return nil @@ -102,6 +118,17 @@ enum NBType { return (value, Double(values.filter { $0 == value }.count) / Double(count)) } return categoryProba + + } else if case .bernoulli = self { + + guard let values = values as? [Bool] else { + return nil + } + + let count = values.count + let categoryProba = [(0, 1.0 - Double(values.filter { $0 }.count) / Double(count)), + (1, Double(values.filter { $0 }.count) / Double(count))] + return categoryProba } return nil @@ -126,6 +153,8 @@ class NaiveBayes { throw "When using Gaussian NB you have to have continuous features (Double)" } else if case .multinomial = type, T.self != Int.self { throw "When using Multinomial NB you have to have categorical features (Int)" + } else if case .bernoulli = type, T.self != Double.self { + throw "When using Bernoulli NB you have to have continuous features (Double)" } }