2024-07-29 11:43:52 +08:00

95 lines
3.2 KiB
Python

import os
import cv2
import numpy as np
import tensorflow as tf
from flask_app.detect.ct_dataset import COVIDxCTDataset
# Names of tensors looked up in the imported meta graph.
IMAGE_INPUT_TENSOR = 'Placeholder:0'
TRAINING_PH_TENSOR = 'is_training:0'
CLASS_PRED_TENSOR = 'ArgMax:0'
CLASS_PROB_TENSOR = 'softmax_tensor:0'

# Human-readable labels, indexed by the model's predicted class id.
CLASS_NAMES = (
    'Normal',
    'Pneumonia',
    'COVID-19',
)
def create_session():
    """Build a TF1-style session whose GPU memory grows on demand.

    Returns:
        tf.Session: a new session configured with
        ``gpu_options.allow_growth = True``.
    """
    session_config = tf.ConfigProto()
    # Allocate GPU memory as needed instead of reserving a fixed-size block
    # up front, so unused capacity is not wasted.
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)
class COVIDNetCTRunner:
    """Runs inference with a COVID-Net CT model restored from a TF1 meta graph.

    Args:
        meta_file: path to the ``.meta`` graph definition to import.
        ckpt: checkpoint prefix to restore weights from; when ``None`` no
            weights are restored.
        data_dir: optional dataset directory; when given, a
            ``COVIDxCTDataset`` is built from it and stored on ``self.dataset``.
        input_height: image height fed to the network.
        input_width: image width fed to the network.
    """

    def __init__(self, meta_file, ckpt=None, data_dir=None, input_height=512, input_width=512):
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        self.data_dir = data_dir
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
            )

    def load_graph(self):
        """Create a new graph and session and import the meta graph into it.

        Returns:
            (graph, sess, saver): the fresh ``tf.Graph``, a session created
            while that graph was the default, and the ``Saver`` returned by
            ``import_meta_graph``. The caller owns ``sess`` and must close it.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Created while ``graph`` is the default graph, so it runs ops
            # from this graph.
            sess = create_session()
            print('Loading meta graph from ' + self.meta_file)
            try:
                saver = tf.train.import_meta_graph(self.meta_file, clear_devices=True)
            except Exception:
                # Don't leak the session if the meta graph cannot be imported.
                sess.close()
                raise
        return graph, sess, saver

    def load_ckpt(self, sess, saver):
        """Restore weights from ``self.ckpt`` into ``sess``; no-op when ``ckpt`` is None."""
        if self.ckpt is not None:
            print('Loading weights from ' + self.ckpt)
            saver.restore(sess, self.ckpt)

    def _preprocess(self, image_file):
        """Load ``image_file`` and return a (1, H, W, 3) float32 batch in [0, 1].

        Raises:
            ValueError: if OpenCV cannot read the file as an image.
        """
        image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError('Could not read image file: {}'.format(image_file))
        # ``interpolation`` must be passed by keyword: the third positional
        # parameter of cv2.resize is ``dst``, so a positional INTER_CUBIC would
        # be ignored and the default interpolation used instead.
        image = cv2.resize(
            image,
            (self.input_width, self.input_height),
            interpolation=cv2.INTER_CUBIC,
        )
        image = image.astype(np.float32) / 255.0
        # Replicate the single gray channel into 3 channels, then add a batch axis.
        image = np.stack((image, image, image), axis=-1)
        return np.expand_dims(image, axis=0)

    def infer(self, image_file, autocrop=False):
        """Classify a single CT image.

        A new graph and session are built (and closed) on every call.

        Args:
            image_file: path to an image readable by ``cv2.imread``.
            autocrop: accepted for API compatibility but currently ignored;
                no body cropping is performed.

        Returns:
            (pred_type, pred_normal, pred_pneu, pred_covid): the predicted
            label from ``CLASS_NAMES`` followed by the probabilities at
            indices 0, 1 and 2, each rounded to 3 decimals.

        Raises:
            ValueError: if ``image_file`` cannot be read as an image.
        """
        image = self._preprocess(image_file)
        graph, sess, saver = self.load_graph()
        try:
            with graph.as_default():
                self.load_ckpt(sess, saver)
                feed_dict = {IMAGE_INPUT_TENSOR: image}
                # Feed the training flag only if this graph defines it; feeding
                # a nonexistent tensor name would make sess.run raise.
                try:
                    sess.graph.get_tensor_by_name(TRAINING_PH_TENSOR)
                except KeyError:
                    pass
                else:
                    feed_dict[TRAINING_PH_TENSOR] = False
                class_, probs = sess.run(
                    [CLASS_PRED_TENSOR, CLASS_PROB_TENSOR], feed_dict=feed_dict
                )
        finally:
            # Each call owns its session; release its resources even on error.
            sess.close()
        pred_type = CLASS_NAMES[class_[0]]
        pred_normal = round(probs[0][0], 3)
        pred_pneu = round(probs[0][1], 3)
        pred_covid = round(probs[0][2], 3)
        return pred_type, pred_normal, pred_pneu, pred_covid
def detectct(imagepath, *, model_dir="models/COVID-Net_CT-2_L", meta_name="model.meta",
             ckpt_name="model", input_height=512, input_width=512):
    """Classify a CT image with a COVID-Net CT model stored on disk.

    A new ``COVIDNetCTRunner`` is created for every call.

    Args:
        imagepath: path to the image to classify.
        model_dir: directory holding the model files. The default is a
            relative path, so it is resolved against the current working
            directory.
        meta_name: file name of the meta graph inside ``model_dir``.
        ckpt_name: checkpoint prefix inside ``model_dir``.
        input_height: height the image is resized to before inference.
        input_width: width the image is resized to before inference.

    Returns:
        The tuple returned by ``COVIDNetCTRunner.infer``: the predicted class
        name followed by three rounded class probabilities.
    """
    runner = COVIDNetCTRunner(
        meta_file=os.path.join(model_dir, meta_name),
        ckpt=os.path.join(model_dir, ckpt_name),
        input_height=input_height,
        input_width=input_width,
    )
    return runner.infer(imagepath)