Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions scripts/builtin/detectMissingType.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------


# INPUT:
# ------------------------------------------------------------------------------------
# X Matrix[Double] (n x p)
# Data matrix with possible missing values (NA)
# alpha Double
# Significance level for MCAR chi-square tests (default = 0.05)
# auc_thresh Double
# Threshold on AUC to distinguish MAR vs MNAR (default = 0.70)
# ------------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------------
# DetectedCodes Matrix[Double] (p x 1)
# Missingness type detected for each column:
# 0 = no missing values
# 1 = MCAR
# 2 = MAR
# 3 = MNAR
# ------------------------------------------------------------------------------------
#
# The following function analyzes every column of the dataset to identify its missingness mechanism.
# For each column with missing values, it first checks whether the missingness depends on the other variables (MCAR test via chi-squared).
# If MCAR is rejected, the function checks whether the missingness can be explained by the other observed columns, using a logistic regression
# and its AUC. In this way we differentiate between MAR and MNAR.
# At the end, it returns one code per column indicating the type of missingness.

m_detectMissingType = function(Matrix[Double] X, Double alpha = 0.05, Double auc_thresh = 0.70)
return (Matrix[Double] DetectedCodes) {

  n_cols = ncol(X);
  # a pairwise test is only run when more than this many rows are observed in the other column
  min_rows = 20;
  DetectedCodes = matrix(0, rows=n_cols, cols=1);

  print("Col\tMissing\tMin_PVal\tAUC\tPredizione");
  print("---------------------------------------------------------");

  for (i in 1:n_cols) {
    col_i = X[, i];
    na_mask = is.na(col_i);
    missing_count = sum(na_mask);

    # values reported for a column without any missing entry
    detected_code = 0;
    label = "NONE";
    best_pval = 1.0;
    auc_val = 0.5;

    if (missing_count > 0) {
      # missingness indicator: 1 = observed, 2 = missing
      miss_flag = na_mask + 1;

      # Step 1 (MCAR): test the indicator against every other column and
      # keep the smallest chi-square p-value found.
      for (j in 1:n_cols) {
        if (i != j) {
          other = X[, j];
          keep = (is.na(other) == 0);

          if (sum(keep) > min_rows) {
            # restrict both vectors to the rows where the other column is observed
            flags_obs = removeEmpty(target=miss_flag, margin="rows", select=keep);
            vals_obs = removeEmpty(target=other, margin="rows", select=keep);

            # a constant column cannot be binned
            if (max(vals_obs) - min(vals_obs) > 0) {
              bins = bin_into_5(vals_obs);
              counts = table(flags_obs, bins);

              if (nrow(counts) > 1 & ncol(counts) > 1) {
                pval = chisq_pval(counts);
                if (pval < best_pval) { best_pval = pval; }
              }
            }
          }
        }
      }

      if (best_pval > alpha) {
        # no column showed a significant association with the missingness
        label = "MCAR";
        detected_code = 1;
      } else {
        # Step 2 (MAR vs MNAR): how well do the remaining columns predict the indicator?
        others = drop_column(X, i);
        auc_val = mar_auc_score(miss_flag, others);

        if (auc_val >= auc_thresh) {
          label = "MAR ";
          detected_code = 2;
        } else {
          label = "MNAR";
          detected_code = 3;
        }
      }
    }

    DetectedCodes[i, 1] = detected_code;
    print(i + "\t" + missing_count + "\t" + toString(best_pval) + "\t" + auc_val + "\t" + label);
  }
}



# Score how well the other columns explain a missingness indicator.
# A logistic regression is fitted on the rows where every feature is observed and
# its AUC on those same rows is returned; higher values suggest the missingness
# depends on the observed variables (probably MAR).
#
# miss_flag  indicator vector (the caller encodes 1 = observed, 2 = missing)
# features   the remaining columns of the dataset (may contain NaN)
# auc_val    AUC of the fitted model, or 0.5 when no model can be fitted
mar_auc_score = function(Matrix[Double] miss_flag, Matrix[Double] features) return (Double auc_val) {
  # fallback returned when the regression is skipped
  auc_val = 0.5;

  # keep only the rows without any missing feature
  row_ok = (rowSums(is.na(features)) == 0);
  feats = removeEmpty(target=features, margin="rows", select=row_ok);
  flags = removeEmpty(target=miss_flag, margin="rows", select=row_ok);

  # fit only with at least two rows and two distinct indicator values
  enough_rows = nrow(flags) >= 2;
  both_classes = min(flags) != max(flags);

  if (enough_rows & both_classes) {
    B = multiLogReg(X=feats, Y=flags, icpt=2, verbose=FALSE);
    [probs, yhat, acc] = multiLogRegPredict(X=feats, B=B, Y=flags);
    # labels shifted from {1,2} to {0,1}; second probability column is used as the score
    auc_val = auc(Y=(flags - 1), P=probs[, 2]);
  }
}

# Bin a numeric vector into 5 equal-width bins (1..5).
#
# v  numeric vector, expected to be free of NaN
# b  bin index per entry: 1 for the lowest fifth of the value range, 5 for the highest
#
# The maximum value is clamped into bin 5 so that all five bins have the same
# width. A vector without a positive range (e.g. a constant vector) is assigned
# entirely to bin 1 instead of dividing by zero.
bin_into_5 = function(Matrix[Double] v)
return (Matrix[Double] b)
{
  vmin = min(v);
  span = max(v) - vmin;

  if (span > 0) {
    # scaled position in [0, 5]; only the maximum reaches 5 and is clamped back
    b = min(floor(((v - vmin) / span) * 5) + 1, 5);
  } else {
    b = matrix(1, rows=nrow(v), cols=ncol(v));
  }
}

# Chi-square independence test p-value for a contingency table F.
#
# F  contingency table of non-negative counts
# p  upper-tail p-value of the Pearson chi-square statistic; 1.0 when the test
#    is not defined (empty table, or fewer than two non-empty rows or columns)
#
# All-zero rows and columns are dropped before testing: they contribute nothing
# to the statistic but would otherwise inflate the degrees of freedom and with
# them the p-value.
chisq_pval = function(Matrix[Double] F)
return (Double p)
{
  p = 1.0;
  total = sum(F);

  if (total > 0) {
    Fc = removeEmpty(target=F, margin="rows");
    Fc = removeEmpty(target=Fc, margin="cols");
    dfv = (nrow(Fc) - 1) * (ncol(Fc) - 1);

    if (dfv > 0) {
      # every remaining row/column has a positive sum, so expected counts are > 0
      expected = (rowSums(Fc) %*% colSums(Fc)) / total;
      chi2 = sum((Fc - expected)^2 / expected);
      p = 1.0 - cdf(target=chi2, dist="chisq", df=dfv);
    }
  }
}

# Remove column idx from matrix X (keeps the order of the remaining columns).
#
# X      input matrix with at least two columns
# idx    1-based index of the column to remove
# X_out  X without column idx
#
# Stops with an explicit message when idx is outside [1, ncol(X)] or when X has
# a single column, instead of failing later on an invalid index range.
drop_column = function(Matrix[Double] X, Integer idx)
return (Matrix[Double] X_out)
{
  p = ncol(X);
  X_out = X;

  if (idx < 1 | idx > p) {
    stop("drop_column: column index " + idx + " is out of range [1, " + p + "]");
  }
  if (p == 1) {
    stop("drop_column: cannot remove the only column of a matrix");
  }

  if (idx == 1) {
    X_out = X[, 2:p];
  } else if (idx == p) {
    X_out = X[, 1:(p-1)];
  } else {
    X_out = cbind(X[, 1:(idx-1)], X[, (idx+1):p]);
  }
}

1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public enum Builtins {
DEDUP("dedup", true),
DEEPWALK("deepWalk", true),
DET("det", false),
DETECTMISSINGTYPE("detectMissingType", true),
DETECTSCHEMA("detectSchema", false),
DENIALCONSTRAINTS("denialConstraints", true),
DIFFERENCESTATISTICS("differenceStatistics", true),
Expand Down
94 changes: 94 additions & 0 deletions src/test/scripts/functions/builtin/create_dataset.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

# Inject MCAR missingness: every row is set to NaN independently of the data.
#
# col_in     input column
# p_missing  threshold compared against one uniform draw per row
# seed       seed of the random draw
# col_out    col_in with NaN in the selected rows
apply_mcar = function(Matrix[Double] col_in, Double p_missing, Integer seed) return (Matrix[Double] col_out) {
  # one uniform draw per row; rows below the threshold become missing
  draws = rand(rows=nrow(col_in), cols=1, min=0, max=1, seed=seed);
  hit = (draws < p_missing);

  # multiplicative mask: NaN where missing, 1 elsewhere
  mask = replace(target=hit * 1.0, pattern=1, replacement=NaN);
  mask = replace(target=mask, pattern=0, replacement=1);

  col_out = col_in * mask;
}

# Inject MAR missingness: the probability of a row being missing is a logistic
# function of two reference columns.
#
# col_in     input column
# ref1       first reference column (weight 1.2 after standardization)
# ref2       second reference column (weight -0.8 after standardization)
# seed       seed of the random draw
# base_prob  baseline probability defining the intercept of the logit
# strength   multiplier applied to the weighted reference columns
# col_out    col_in with NaN in the selected rows
apply_mar = function(Matrix[Double] col_in, Matrix[Double] ref1, Matrix[Double] ref2, Integer seed, Double base_prob, Double strength) return (Matrix[Double] col_out) {

  # standardize both reference columns
  z1 = scale(ref1, TRUE, TRUE);
  z2 = scale(ref2, TRUE, TRUE);

  # per-row probability from the logit, clipped to [0.10, 0.85]
  intercept = log(base_prob / (1.0 - base_prob));
  score = intercept + strength * (1.2 * z1 - 0.8 * z2);
  prob = 1 / (1 + exp(-score));
  prob = max(min(prob, 0.85), 0.10);

  # rows whose uniform draw falls below their probability become missing
  hit = (rand(rows=nrow(col_in), cols=1, seed=seed) < prob);

  # multiplicative mask: NaN where missing, 1 elsewhere
  mask = replace(target=hit * 1.0, pattern=1, replacement=NaN);
  mask = replace(target=mask, pattern=0, replacement=1);

  col_out = col_in * mask;
}

# Inject MNAR missingness: the probability of a row being missing grows with
# the row's own value.
#
# col_in     input column
# p_missing  probability applied to the largest value; smaller values get
#            proportionally less after min-max normalization
# seed       seed of the random draw
# col_out    col_in with NaN in the selected rows
apply_mnar = function(Matrix[Double] col_in, Double p_missing, Integer seed) return (Matrix[Double] col_out) {

  # min-max normalize the column to [0, 1]
  lo = min(col_in);
  hi = max(col_in);
  y_norm = (col_in - lo) / (hi - lo);

  # per-row threshold scales with the normalized value
  draws = rand(rows=nrow(col_in), cols=1, seed=seed);
  is_missing = (draws < (y_norm * p_missing));

  # adding NaN marks the selected rows, adding 0 keeps the others
  col_out = col_in + (ifelse(is_missing, NaN, 0));
}

# Generate a synthetic dataset with one column per ground-truth entry.
#
# n            number of rows to generate
# GroundTruth  column vector of missingness codes, one per output column
#              (0 = no missing values, 1 = MCAR, 2 = MAR, 3 = MNAR);
#              an empty matrix selects the default 9-column layout
# X            generated data matrix (n x nrow(GT_out))
# GT_out       the ground-truth vector actually used
#
# Columns 1 and 2 are the two normally distributed anchor columns; every later
# column is a linear mix of both anchors plus uniform noise. Stops on a code
# outside 0..3 so that X always has exactly one column per ground-truth entry.
get_dataset = function(Integer n, Matrix[Double] GroundTruth = matrix(0,0,0)) return (Matrix[Double] X, Matrix[Double] GT_out) {

  # create groundTruth
  if (nrow(GroundTruth) == 0) {
    # default groundTruth
    GT_out = matrix("0 0 0 1 1 2 2 3 3", rows=9, cols=1);
  } else {
    GT_out = GroundTruth;
  }

  num_cols = nrow(GT_out);
  # preallocate the result instead of growing it with cbind
  X = matrix(0, rows=n, cols=num_cols);

  # gen anchor columns
  anchor1 = rand(rows=n, cols=1, pdf="normal", seed=42);
  anchor2 = rand(rows=n, cols=1, pdf="normal", seed=43);

  # gen columns
  for (i in 1:num_cols) {
    code = as.scalar(GT_out[i, 1]);
    col_base = matrix(0, rows=n, cols=1);

    if (i == 1) {
      col_base = anchor1;
    } else if (i == 2) {
      col_base = anchor2;
    } else {
      # gen noise & correlation
      noise = rand(rows=n, cols=1, seed=i*123);
      col_base = 0.5 * anchor1 + 0.3 * anchor2 + noise * 0.2;
    }

    # apply missing type; code 0 keeps the column complete
    col_i = col_base;
    if (code == 1) {
      col_i = apply_mcar(col_base, 0.20, i);
    }
    else if (code == 2) {
      col_i = apply_mar(col_base, anchor1, anchor2, i, 0.15, 3.0);
    }
    else if (code == 3) {
      col_i = apply_mnar(col_base, 0.25, i);
    }
    else if (code != 0) {
      # an unknown code used to drop the column silently and misalign X with GT_out
      stop("get_dataset: unsupported missingness code " + code + " for column " + i);
    }

    X[, i] = col_i;
  }
  print("Generation dataset completed");
}
46 changes: 46 additions & 0 deletions src/test/scripts/functions/builtin/testDetectMissingType.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
source("src/test/scripts/functions/builtin/create_dataset.dml") as gen;

# Run four scenarios: generate a dataset with a known missingness layout,
# run detectMissingType on it and print the share of correctly labelled columns.
for (t in 1:4) {
  print("=========================================================");

  # every scenario uses 5000 rows unless it overrides the value below
  n_rows = 5000;
  test_name = "";

  if (t == 1) {
    gt_test = matrix("0 0 0 1 1 2 2 3 3", rows=9, cols=1);
    test_name = "9 column MIX (Default)";
  }
  else if (t == 2) {
    gt_test = matrix("0 1 2 3", rows=4, cols=1);
    test_name = "4 Easy column";
  }
  else if (t == 3) {
    gt_test = matrix("0 0 2 2 2 2 2 2 2 2", rows=10, cols=1);
    test_name = "10 unbalanced columns (MAR)";
  }
  else {
    n_rows = 150;
    gt_test = matrix("0 1 2 3", rows=4, cols=1);
    test_name = "small number of rows (N=150)";
  }

  print("Execute TEST " + t + ": " + test_name);

  # generate the data together with its ground-truth codes
  [data, truth] = gen::get_dataset(n=n_rows, GroundTruth=gt_test);

  # detect the missingness type of every column (alpha = 0.05, AUC threshold = 0.60)
  results = detectMissingType(data, 0.05, 0.60);

  # percentage of columns whose detected code equals the ground truth
  n_correct = sum(results == truth);
  accuracy = (n_correct / nrow(truth)) * 100;

  print("---------------------------------------------------------");
  print("Accuracy: " + toString(accuracy) + "%");
}

print("=========================================================");
print("Test suite completed");