Source code for dl_data_validation_toolset.data_tests.labels

import numpy as np
import logging
from dl_data_validation_toolset.framework import base_test


[docs]class LabelTests(base_test.BaseTest): logger = logging.getLogger('data_tests.labels')
[docs] def test_label_exists(self): labels = self._file['label/type'] logging.debug("Labels: {}".format(labels)) valid = int(labels is not None) valid += int(len(labels)) return {'N Labels': len(labels)}, valid
[docs] def test_nonzero_labels(self): labels = self._file['label/type'] null_vectors = 0 valid = 2 for labelvec in labels: if np.max(labelvec) == 0: self.logger.debug("Found null vector: {}".format(labelvec)) valid = 1 null_vectors += 1 if null_vectors == len(labels): valid = 0 return {'null_vectors': null_vectors}, valid
[docs] def test_label_diversity(self): labels = self._file['label/type'] label_accumulator = np.ndarray(shape=(len(labels[0]))) for labelvec in labels: label_accumulator += labelvec result = {} result['valid'] = np.max(label_accumulator) == 0 result['valid'] = result['valid'] or np.min(label_accumulator) == 0 result['labels'] = label_accumulator return result, 2