import os

import cv2
import numpy as np
import tensorflow as tf

from flask_app.detect.ct_dataset import COVIDxCTDataset

# Names of tensors looked up in the imported COVID-Net CT meta graph.
IMAGE_INPUT_TENSOR = 'Placeholder:0'
CLASS_PRED_TENSOR = 'ArgMax:0'
CLASS_PROB_TENSOR = 'softmax_tensor:0'
TRAINING_PH_TENSOR = 'is_training:0'

# Class labels, indexed by the value produced by CLASS_PRED_TENSOR; the same
# index order is used to read the per-class entries of CLASS_PROB_TENSOR.
CLASS_NAMES = ('Normal', 'Pneumonia', 'COVID-19')


def create_session():
    """Create a TF1 session whose GPU memory is allocated on demand."""
    config = tf.ConfigProto()
    # Grow GPU memory as needed instead of reserving a fixed large block up
    # front, which would waste memory.
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    return sess


class COVIDNetCTRunner:
    """Load an exported COVID-Net CT graph/checkpoint and run image inference."""

    def __init__(self, meta_file, ckpt=None, data_dir=None, input_height=512, input_width=512):
        """
        Args:
            meta_file: path to the ``.meta`` graph definition to import.
            ckpt: checkpoint prefix to restore weights from; if None, no
                weights are restored.
            data_dir: optional dataset directory; when given, a
                ``COVIDxCTDataset`` is built from it.
            input_height: height images are resized to before inference.
            input_width: width images are resized to before inference.
        """
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        self.data_dir = data_dir
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
            )

    def load_graph(self):
        """Create a new graph and session and import the meta graph into it.

        Returns:
            ``(graph, sess, saver)``. The caller owns ``sess`` and is
            responsible for closing it.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Created under graph.as_default(), so the session is bound to `graph`.
            sess = create_session()
            print('Loading meta graph from ' + self.meta_file)
            try:
                saver = tf.train.import_meta_graph(self.meta_file, clear_devices=True)
            except Exception:
                # Don't leak the session if the graph cannot be imported.
                sess.close()
                raise
        return graph, sess, saver

    def load_ckpt(self, sess, saver):
        """Restore weights from ``self.ckpt`` into ``sess`` when a checkpoint is set."""
        if self.ckpt is not None:
            print('Loading weights from ' + self.ckpt)
            saver.restore(sess, self.ckpt)

    def _load_image(self, image_file):
        """Read ``image_file`` as grayscale and build a ``(1, H, W, 3)`` float32 batch.

        Pixel values are scaled by 1/255, and the single gray channel is
        replicated three times.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('Could not read image file: {}'.format(image_file))
        # The third positional argument of cv2.resize is `dst`, so the
        # interpolation flag must be passed by keyword to take effect.
        image = cv2.resize(
            image, (self.input_width, self.input_height), interpolation=cv2.INTER_CUBIC
        )
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(np.stack((image, image, image), axis=-1), axis=0)

    def infer(self, image_file, autocrop=False):
        """Classify a single image file.

        Args:
            image_file: path to an image readable by ``cv2.imread``.
            autocrop: accepted for API compatibility; not used.

        Returns:
            ``(pred_type, pred_normal, pred_pneu, pred_covid)``: the predicted
            class name from ``CLASS_NAMES``, followed by the probabilities at
            indices 0, 1 and 2, each rounded to 3 decimals.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        image = self._load_image(image_file)

        # NOTE(review): the graph and weights are reloaded on every call;
        # consider caching them if inference latency matters.
        graph, sess, saver = self.load_graph()
        try:
            with graph.as_default():
                self.load_ckpt(sess, saver)
                feed_dict = {IMAGE_INPUT_TENSOR: image}
                # Feed the training flag only if the graph defines it;
                # feeding a non-existent tensor would make sess.run fail.
                try:
                    graph.get_tensor_by_name(TRAINING_PH_TENSOR)
                except KeyError:
                    pass
                else:
                    feed_dict[TRAINING_PH_TENSOR] = False
                class_, probs = sess.run(
                    [CLASS_PRED_TENSOR, CLASS_PROB_TENSOR], feed_dict=feed_dict
                )
        finally:
            # Release the session (and its device memory) after every request.
            sess.close()

        pred_type = CLASS_NAMES[int(class_[0])]
        # Convert to Python floats so results are JSON-serializable.
        pred_normal = round(float(probs[0][0]), 3)
        pred_pneu = round(float(probs[0][1]), 3)
        pred_covid = round(float(probs[0][2]), 3)
        return pred_type, pred_normal, pred_pneu, pred_covid


def detectct(
    imagepath,
    model_dir="models/COVID-Net_CT-2_L",
    meta_name="model.meta",
    ckpt_name="model",
    input_height=512,
    input_width=512,
):
    """Run COVID-Net CT inference on ``imagepath``.

    Args:
        imagepath: path to the image to classify.
        model_dir: directory holding the meta graph and checkpoint.
        meta_name: file name of the meta graph inside ``model_dir``.
        ckpt_name: checkpoint prefix inside ``model_dir``.
        input_height: height images are resized to.
        input_width: width images are resized to.

    Returns:
        The tuple returned by ``COVIDNetCTRunner.infer``.
    """
    meta_file = os.path.join(model_dir, meta_name)
    ckpt = os.path.join(model_dir, ckpt_name)
    runner = COVIDNetCTRunner(
        meta_file=meta_file,
        ckpt=ckpt,
        input_height=input_height,
        input_width=input_width,
    )
    return runner.infer(imagepath)