95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from flask_app.detect.ct_dataset import COVIDxCTDataset
|
|
# Tensor names looked up in the imported COVID-Net CT meta graph.
IMAGE_INPUT_TENSOR = "Placeholder:0"   # fed with the preprocessed image batch
TRAINING_PH_TENSOR = "is_training:0"   # fed False during inference
CLASS_PRED_TENSOR = "ArgMax:0"         # fetched; indexes CLASS_NAMES
CLASS_PROB_TENSOR = "softmax_tensor:0" # fetched; per-class probabilities

# Human-readable labels, in the order of the model's class indices.
CLASS_NAMES = ("Normal", "Pneumonia", "COVID-19")
|
|
|
|
def create_session():
    """Build a TF1 ``tf.Session`` whose GPU memory grows on demand.

    Returns:
        A new, open ``tf.Session``; closing it is left to the caller.
    """
    session_config = tf.ConfigProto()
    # Allocate GPU memory as needed rather than reserving a fixed amount
    # up front, which would waste memory.
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)
|
|
|
|
class COVIDNetCTRunner:
    """Primary training/testing/inference class for a COVID-Net CT model.

    The model is restored from a TF1 meta graph (``meta_file``) and,
    optionally, a checkpoint (``ckpt``). Only single-image inference is
    implemented here (see ``infer``).
    """

    def __init__(self, meta_file, ckpt=None, data_dir=None, input_height=512, input_width=512):
        """Store model paths and input geometry.

        Args:
            meta_file: Path to the ``.meta`` file passed to
                ``tf.train.import_meta_graph``.
            ckpt: Checkpoint prefix passed to ``saver.restore``; when None,
                no weights are restored.
            data_dir: Optional dataset directory. When given, a
                ``COVIDxCTDataset`` is built from it; otherwise
                ``self.dataset`` is None.
            input_height: Height in pixels that input images are resized to.
            input_width: Width in pixels that input images are resized to.
        """
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        self.data_dir = data_dir
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
            )

    def load_graph(self):
        """Create a new graph and session, and import the meta graph into it.

        Returns:
            ``(graph, sess, saver)``: the new ``tf.Graph``, a session created
            while that graph was the default, and the ``Saver`` returned by
            ``import_meta_graph``. The caller owns ``sess`` and must close it.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Created under graph.as_default(), so the session runs `graph`.
            sess = create_session()
            print('Loading meta graph from ' + self.meta_file)
            try:
                # clear_devices=True strips device placements stored in the
                # meta graph so it can load on whatever devices are present.
                saver = tf.train.import_meta_graph(self.meta_file, clear_devices=True)
            except Exception:
                # Do not leak the freshly created session if import fails.
                sess.close()
                raise
        return graph, sess, saver

    def load_ckpt(self, sess, saver):
        """Restore weights from ``self.ckpt`` into ``sess``; no-op if ckpt is None."""
        if self.ckpt is not None:
            print('Loading weights from ' + self.ckpt)
            saver.restore(sess, self.ckpt)

    def infer(self, image_file, autocrop=False):
        """Classify a single CT image.

        A new graph and session are built (and closed) on every call.

        Args:
            image_file: Path to an image readable by ``cv2.imread``; it is
                loaded as grayscale.
            autocrop: Accepted for interface compatibility; currently ignored.

        Returns:
            ``(pred_type, pred_normal, pred_pneu, pred_covid)``: the predicted
            class name from ``CLASS_NAMES`` and the three class probabilities,
            each rounded to 3 decimals.

        Raises:
            OSError: If ``cv2.imread`` cannot read ``image_file``.
        """
        image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError('Unable to read image file: ' + str(image_file))
        # `interpolation` must be a keyword: the third positional argument of
        # cv2.resize is `dst`, so passing the flag positionally ignored it.
        image = cv2.resize(image, (self.input_width, self.input_height),
                           interpolation=cv2.INTER_CUBIC)
        image = image.astype(np.float32) / 255.0
        # Replicate the gray channel into 3 channels and add a batch axis:
        # (H, W) -> (1, H, W, 3).
        image = np.expand_dims(np.stack((image, image, image), axis=-1), axis=0)

        # The training flag is added below only if the graph defines it;
        # feeding a tensor that does not exist would make sess.run fail.
        feed_dict = {IMAGE_INPUT_TENSOR: image}

        graph, sess, saver = self.load_graph()
        try:
            with graph.as_default():
                self.load_ckpt(sess, saver)
                try:
                    sess.graph.get_tensor_by_name(TRAINING_PH_TENSOR)
                except KeyError:
                    pass
                else:
                    feed_dict[TRAINING_PH_TENSOR] = False
                class_, probs = sess.run([CLASS_PRED_TENSOR, CLASS_PROB_TENSOR],
                                         feed_dict=feed_dict)
        finally:
            # Each call opens its own session; release it deterministically.
            sess.close()

        pred_type = CLASS_NAMES[class_[0]]
        pred_normal = round(probs[0][0], 3)
        pred_pneu = round(probs[0][1], 3)
        pred_covid = round(probs[0][2], 3)
        return pred_type, pred_normal, pred_pneu, pred_covid
|
|
|
|
|
|
def detectct(imagepath, model_dir="models/COVID-Net_CT-2_L", meta_name="model.meta",
             ckpt_name="model", input_height=512, input_width=512):
    """Run COVID-Net CT inference on a single image.

    The defaults select the ``COVID-Net_CT-2_L`` model under ``models/``,
    with 512x512 inputs. That relative path is resolved against the current
    working directory. Override the keyword arguments to use a different
    model or input size.

    Args:
        imagepath: Path of the image to classify.
        model_dir: Directory containing the meta graph and checkpoint.
        meta_name: File name of the meta graph inside ``model_dir``.
        ckpt_name: Checkpoint prefix inside ``model_dir``.
        input_height: Height images are resized to before inference.
        input_width: Width images are resized to before inference.

    Returns:
        The result of ``COVIDNetCTRunner.infer``:
        ``(pred_type, pred_normal, pred_pneu, pred_covid)``.
    """
    meta_file = os.path.join(model_dir, meta_name)
    ckpt = os.path.join(model_dir, ckpt_name)
    runner = COVIDNetCTRunner(
        meta_file=meta_file,
        ckpt=ckpt,
        input_height=input_height,
        input_width=input_width,
    )
    # NOTE: infer() rebuilds the graph and restores weights on every call.
    return runner.infer(imagepath)
|
|
|