24 lines
781 B
Python
24 lines
781 B
Python
import os
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
class COVIDxCTDataset:
    """COVIDx CT dataset class, which handles construction of train/validation datasets.

    ``_make_dataset`` relies on a ``_get_files(split_file)`` hook that must
    return ``(files, classes, bboxes)``. This class only provides a stub for
    that hook; a subclass (or a later revision) must supply the real parser.
    """

    def __init__(self, data_dir, image_height: int = 512, image_width: int = 512):
        """Store general dataset parameters.

        Args:
            data_dir: Root data directory. Stored on the instance; no code in
                this class currently reads it.
            image_height: Image height setting (default 512). Stored only.
            image_width: Image width setting (default 512). Stored only.
        """
        # General parameters
        self.data_dir = data_dir
        self.image_height = image_height
        self.image_width = image_width

    def _get_files(self, split_file):
        """Return ``(files, classes, bboxes)`` read from ``split_file``.

        ``_make_dataset`` depends on this hook, but no concrete implementation
        exists in this class. Defining it explicitly turns the previous opaque
        ``AttributeError`` into a clear error telling the caller what to
        override.

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_files(split_file) "
            "returning (files, classes, bboxes)"
        )

    def _make_dataset(self, split_file, batch_size, is_training, balanced=True):
        """Creates COVIDX-CT dataset for train or val split.

        Args:
            split_file: Split description passed straight to ``_get_files``.
            batch_size: Batch size passed to ``Dataset.batch``.
            is_training: Accepted for interface compatibility; currently unused.
            balanced: Accepted for interface compatibility; currently unused.

        Returns:
            A tuple ``(dataset, count, batch_size)``:
            - ``dataset``: ``tf.data.Dataset`` built by slicing
              ``(files, classes, bboxes)`` element-wise, then batched.
            - ``count``: ``len(files)``, the number of un-batched elements.
            - ``batch_size``: the batch size argument, echoed back.
        """
        # NOTE(review): is_training / balanced do not affect the pipeline
        # (no shuffling or class balancing is applied) — confirm intent.
        files, classes, bboxes = self._get_files(split_file)
        count = len(files)
        # from_tensor_slices needs files/classes/bboxes to share the same
        # leading length; presumably guaranteed by _get_files — TODO confirm.
        dataset = tf.data.Dataset.from_tensor_slices((files, classes, bboxes))
        dataset = dataset.batch(batch_size)
        return dataset, count, batch_size
|
|
|
|
|
|
|
|
|