2024-07-29 11:43:52 +08:00

24 lines
781 B
Python

import os
import numpy as np
import tensorflow as tf
class COVIDxCTDataset:
    """Build batched ``tf.data`` pipelines for the COVIDx CT train/validation splits."""

    def __init__(self, data_dir, image_height=512, image_width=512):
        """Store the dataset configuration on the instance.

        Args:
            data_dir: Stored unchanged as ``self.data_dir``.
            image_height: Stored unchanged as ``self.image_height`` (default 512).
            image_width: Stored unchanged as ``self.image_width`` (default 512).
        """
        # Configuration is only recorded here; nothing is read or validated.
        self.data_dir, self.image_height, self.image_width = (
            data_dir,
            image_height,
            image_width,
        )

    def _make_dataset(self, split_file, batch_size, is_training, balanced=True):
        """Assemble a batched dataset for the split described by ``split_file``.

        Args:
            split_file: Passed straight through to ``self._get_files``.
            batch_size: Number of consecutive elements grouped into each batch.
            is_training: Accepted for interface compatibility; not used here.
            balanced: Accepted for interface compatibility; not used here.

        Returns:
            A 3-tuple ``(dataset, count, batch_size)`` where ``dataset`` is the
            batched ``tf.data.Dataset``, ``count`` is the number of entries
            returned by ``_get_files`` (before batching), and ``batch_size``
            is echoed back unchanged.
        """
        # NOTE(review): `_get_files` is not defined in the class body shown
        # here; it must be supplied elsewhere (e.g. a subclass) -- confirm.
        paths, labels, boxes = self._get_files(split_file)
        num_examples = len(paths)

        # Slice the three parallel sequences element-wise, then group them
        # into batches of `batch_size`.
        batched = tf.data.Dataset.from_tensor_slices(
            (paths, labels, boxes)
        ).batch(batch_size)
        return batched, num_examples, batch_size