18 changes: 13 additions & 5 deletions yolo/model/yolo.py
@@ -135,27 +135,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
if "model_state_dict" in weights:
weights = weights["model_state_dict"]
if "state_dict" in weights:
weights = weights["state_dict"]

model_state_dict = self.model.state_dict()
model_state_dict = self.state_dict()

# TODO1: autoload old version weight
# TODO2: weight transform if num_class difference

error_dict = {"Mismatch": set(), "Not Found": set()}
for model_key, model_weight in model_state_dict.items():
if model_key not in weights:

weights_key = model_key
if weights_key not in weights: # .ckpt
weights_key = "model." + model_key
if weights_key not in weights: # .pt old
weights_key = model_key[6:]
if weights_key not in weights:
error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
continue
if model_weight.shape != weights[model_key].shape:
if model_weight.shape != weights[weights_key].shape:
error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
continue
model_state_dict[model_key] = weights[model_key]
model_state_dict[model_key] = weights[weights_key]

for error_name, error_set in error_dict.items():
for weight_name in error_set:
logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

self.model.load_state_dict(model_state_dict)
self.load_state_dict(model_state_dict)


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
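For reference, a minimal sketch of the key-resolution order that `save_load_weights` now tries, written as a free function with hypothetical keys rather than the project's API:

```python
# A sketch of the fallback added above. Keys from self.state_dict() start with
# "model.", so the three candidates are: the raw key, the key with an extra
# "model." prefix (Lightning .ckpt checkpoints), and the key with "model."
# stripped (old .pt checkpoints). Keys and tensors here are made up.
from typing import Optional

def resolve_weight_key(model_key: str, weights: dict) -> Optional[str]:
    candidates = (
        model_key,             # exact match
        "model." + model_key,  # .ckpt: Lightning prepends another "model."
        model_key[6:],         # old .pt: no leading "model." prefix
    )
    for key in candidates:
        if key in weights:
            return key
    return None  # caller logs a "Not Found" warning and skips this key

# An old-style .pt checkpoint resolves via the stripped key:
old_pt_weights = {"backbone.conv1.weight": "tensor..."}
assert resolve_weight_key("model.backbone.conv1.weight", old_pt_weights) == "backbone.conv1.weight"
```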
24 changes: 15 additions & 9 deletions yolo/tools/data_loader.py
@@ -234,7 +234,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st


class StreamDataLoader:
def __init__(self, data_cfg: DataConfig):
def __init__(self, data_cfg: DataConfig, asynchronous: bool = True):
self.source = data_cfg.source
self.running = True
self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
@@ -249,8 +249,12 @@ def __init__(self, data_cfg: DataConfig):
else:
self.source = Path(self.source)
self.queue = Queue()
self.thread = Thread(target=self.load_source)
self.thread.start()

if asynchronous:
self.thread = Thread(target=self.load_source)
self.thread.start()
else:
self.load_source()

def load_source(self):
if self.source.is_dir(): # image folder
@@ -272,20 +276,22 @@ def process_image(self, image_path):
image = Image.open(image_path).convert("RGB")
if image is None:
raise ValueError(f"Error loading image: {image_path}")
self.process_frame(image)
self.process_frame(image, image_path)

def load_video_file(self, video_path):
import cv2

cap = cv2.VideoCapture(str(video_path))
frame_idx = 0
while self.running:
ret, frame = cap.read()
if not ret:
break
self.process_frame(frame)
self.process_frame(frame, f"{video_path.stem}_frame{frame_idx:04d}.png")
frame_idx += 1
cap.release()

def process_frame(self, frame):
def process_frame(self, frame, image_path):
if isinstance(frame, np.ndarray):
# TODO: we don't need cv2
import cv2
@@ -297,9 +303,9 @@ def process_frame(self, frame):
frame = frame[None]
rev_tensor = rev_tensor[None]
if not self.is_stream:
self.queue.put((frame, rev_tensor, origin_frame))
self.queue.put((frame, rev_tensor, origin_frame, image_path))
else:
self.current_frame = (frame, rev_tensor, origin_frame)
self.current_frame = (frame, rev_tensor, origin_frame, image_path)

def __iter__(self) -> Generator[Tensor, None, None]:
return self
@@ -310,7 +316,7 @@ def __next__(self) -> Tensor:
if not ret:
self.stop()
raise StopIteration
self.process_frame(frame)
self.process_frame(frame, "stream_frame.png")
return self.current_frame
else:
try:
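The synchronous path added to `StreamDataLoader.__init__` can be pictured with a toy producer; the class below only illustrates the pattern under that assumption and is not the real loader:

```python
# Minimal sketch of the synchronous/asynchronous loading pattern used above,
# with a toy producer instead of the real image/video sources. With
# asynchronous=False the queue is fully populated before iteration begins,
# so a consumer that drains the queue once cannot miss a frame.
from queue import Queue
from threading import Thread

class ToyStreamLoader:
    def __init__(self, items, asynchronous: bool = True):
        self.queue: Queue = Queue()
        if asynchronous:
            # Background thread: a consumer may start draining before all items are queued.
            self.thread = Thread(target=self._load, args=(items,))
            self.thread.start()
        else:
            # Synchronous: block until every item is queued (the predict path above).
            self._load(items)

    def _load(self, items):
        for item in items:
            self.queue.put(item)

loader = ToyStreamLoader(["img_0.png", "img_1.png"], asynchronous=False)
print(loader.queue.qsize())  # 2 -- all items are queued before any consumption
```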
22 changes: 19 additions & 3 deletions yolo/tools/solver.py
@@ -6,7 +6,7 @@

from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
@@ -112,7 +112,9 @@ def __init__(self, cfg: Config):
super().__init__(cfg)
self.cfg = cfg
# TODO: Add FastModel
self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
# The StreamDataLoader has to run synchronously here, otherwise not all images are loaded before prediction starts
# TODO: Make this load in parallel
self.predict_loader = StreamDataLoader(cfg.task.data, asynchronous=False)

def setup(self, stage):
self.vec2box = create_converter(
@@ -124,15 +126,29 @@ def predict_dataloader(self):
return self.predict_loader

def predict_step(self, batch, batch_idx):
images, rev_tensor, origin_frame = batch
images, rev_tensor, origin_frame, image_path = batch
predicts = self.post_process(self(images), rev_tensor=rev_tensor)
img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)

if getattr(self.predict_loader, "is_stream", None):
fps = self._display_stream(img)
else:
fps = None

if getattr(self.cfg.task, "save_predict", None):
self._save_image(img, batch_idx)

output_txt_file = Path(self.cfg.out_path) / "results.txt"

# save predictions to results.txt, one space-separated line per detection: image name, class id, box coords, confidence
with open(output_txt_file, "a") as f:
for bboxes in predicts:
for bbox in bboxes:
class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
f.write(f"{Path(image_path).name} {int(class_id)} {x_min} {y_min} {x_max} {y_max} {conf[0]}\n")

print(f"💾 Saved predictions at {output_txt_file}")

return img, fps

def _save_image(self, img, batch_idx):
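The appended results.txt lines follow the format `<image name> <class id> <x_min> <y_min> <x_max> <y_max> <confidence>`. A small sketch of a reader for that format, with a made-up sample line:

```python
# Parses the results.txt produced by predict_step above: one detection per
# line, space separated. The example line in the comment is illustrative.
from pathlib import Path

def parse_results(path: Path):
    detections = []
    for line in path.read_text().splitlines():
        name, class_id, x_min, y_min, x_max, y_max, conf = line.split()
        detections.append(
            {
                "image": name,
                "class_id": int(class_id),
                "box": tuple(float(v) for v in (x_min, y_min, x_max, y_max)),
                "confidence": float(conf),
            }
        )
    return detections

# Example line: "street.png 2 34.0 50.5 120.0 200.25 0.91"
```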
3 changes: 2 additions & 1 deletion yolo/utils/logging_utils.py
@@ -107,7 +107,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx:
epoch_descript = "[cyan]Train [white]|"
batch_descript = "[green]Train [white]|"
metrics = self.get_metrics(trainer, pl_module)
metrics.pop("v_num")
if "v_num" in metrics:
metrics.pop("v_num")
for metrics_name, metrics_val in metrics.items():
if "Loss_step" in metrics_name:
epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
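On the guard above, `dict.pop` also accepts a default, which gives the same behavior in one line; a quick sketch with made-up metrics:

```python
# Equivalent guard using dict.pop's default argument: no KeyError when the
# Lightning-added "v_num" entry is absent (e.g. when no logger is attached).
metrics = {"Loss_step/box": 0.42}
metrics.pop("v_num", None)  # same effect as the explicit membership check above
print(metrics)  # {'Loss_step/box': 0.42}
```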