import logging
import time
from typing import List, Optional

import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
import re

# the "contract" classes that a new provider MUST use for its return value.
from src.ocr.interface import BoundingBox, OcrProvider, Paragraph, Word

logger = logging.getLogger(__name__)

# --- model configuration ---
DET_MODEL_REPO = "rtr46/meiki.text.detect.v0"
DET_MODEL_NAME = "meiki.text.detect.v0.1.960x544.onnx"
REC_MODEL_REPO = "rtr46/meiki.txt.recognition.v0"
REC_MODEL_NAME = "meiki.text.rec.v0.960x32.onnx"

# --- pipeline configuration ---
INPUT_DET_WIDTH = 960
INPUT_DET_HEIGHT = 544
INPUT_REC_HEIGHT = 32
INPUT_REC_WIDTH = 960
DET_CONFIDENCE_THRESHOLD = 0.5
REC_CONFIDENCE_THRESHOLD = 0.1
X_OVERLAP_THRESHOLD = 0.3
EPSILON = 1e-6

JAPANESE_REGEX = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF]')
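# the ranges above cover hiragana (U+3040-309F), katakana (U+30A0-30FF) and the bulk of the
# CJK unified ideographs block (U+4E00-9FAF); rarer kanji beyond U+9FAF are not matched.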


class MeikiOcrProvider(OcrProvider):
    """
    An OCR provider that uses the high-performance meikiocr pipeline.
    This provider is specifically optimized for recognizing Japanese text from video games.
    """
    NAME = "meikiocr (local)"

    def __init__(self):
        """
        Initializes the provider and lazy-loads the ONNX models.
        This is called once when the provider is selected in MeikiPop.
        """
        logger.info(f"initializing {self.NAME} provider...")
        self.det_session = None
        self.rec_session = None
        try:
            det_model_path = hf_hub_download(repo_id=DET_MODEL_REPO, filename=DET_MODEL_NAME)
            rec_model_path = hf_hub_download(repo_id=REC_MODEL_REPO, filename=REC_MODEL_NAME)

            # prioritize gpu if available, fallback to cpu.
            available_providers = ort.get_available_providers()
            desired_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            providers_to_use = [p for p in desired_providers if p in available_providers]
            ort.set_default_logger_severity(3)  # suppress verbose logs

            self.det_session = ort.InferenceSession(det_model_path, providers=providers_to_use)
            self.rec_session = ort.InferenceSession(rec_model_path, providers=providers_to_use)

            active_provider = self.det_session.get_providers()[0]
            logger.info(f"{self.NAME} initialized successfully, running on: {active_provider}")

        except Exception as e:
            logger.error(f"failed to initialize {self.NAME}: {e}", exc_info=True)

    def scan(self, image: Image.Image) -> Optional[List[Paragraph]]:
        """
        Performs OCR on the given image using the full meikiocr pipeline.
        """
        if not self.det_session or not self.rec_session:
            logger.error(f"{self.NAME} was not initialized correctly. cannot perform scan.")
            return None

        try:
            start_time = time.perf_counter()

            # convert pil (rgb) image to numpy array for opencv processing.
            image_np = np.array(image.convert("RGB"))
            img_height, img_width = image_np.shape[:2]
            if img_width == 0 or img_height == 0:
                logger.error("invalid image dimensions received.")
                return None

            # --- 1. run detection stage ---
            det_input, sx, sy = self._preprocess_for_detection(image_np)
            det_raw = self._run_detection_inference(det_input)
            text_boxes = self._postprocess_detection_results(det_raw, sx, sy)

            if not text_boxes:
                return []

            # --- 2. run recognition stage ---
            rec_batch, valid_indices, crop_meta = self._preprocess_for_recognition(image_np, text_boxes)
            if rec_batch is None:
                return []

            rec_raw = self._run_recognition_inference(rec_batch)
            ocr_results = self._postprocess_recognition_results(rec_raw, valid_indices, crop_meta, len(text_boxes))

            # --- 3. transform data to meikipop's format ---
            paragraphs = self._to_meikipop_paragraphs(ocr_results, img_width, img_height)

            duration = time.perf_counter() - start_time
            logger.info(f"{self.NAME} processed image in {duration:.3f}s, found {len(paragraphs)} paragraphs.")

            return paragraphs

        except Exception as e:
            logger.error(f"an error occurred in {self.NAME}: {e}", exc_info=True)
            return None  # returning none indicates a failure.

    def _to_normalized_bbox(self, bbox_pixels: list, img_width: int, img_height: int) -> BoundingBox:
        """converts an [x1, y1, x2, y2] pixel bbox to a normalized meikipop BoundingBox."""
        x1, y1, x2, y2 = bbox_pixels
        box_w, box_h = x2 - x1, y2 - y1

        center_x = (x1 + box_w / 2) / img_width
        center_y = (y1 + box_h / 2) / img_height
        norm_w = box_w / img_width
        norm_h = box_h / img_height

        return BoundingBox(center_x, center_y, norm_w, norm_h)

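    # worked example for _to_normalized_bbox, assuming a hypothetical 1920x1080 frame:
    #   pixel bbox [192, 540, 384, 594] -> box_w=192, box_h=54
    #   -> BoundingBox(0.15, 0.525, 0.10, 0.05), i.e. center, width and height as fractions of the frame.
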
    def _to_meikipop_paragraphs(self, ocr_results: list, img_width: int, img_height: int) -> List[Paragraph]:
        """converts the final meikiocr result list into meikipop's Paragraph format."""
        paragraphs: List[Paragraph] = []
        for line_result in ocr_results:
            full_text = line_result.get("text", "").strip()
            chars = line_result.get("chars", [])
            if not full_text or not chars or not JAPANESE_REGEX.search(full_text):
                continue

            # create word objects for each character (best for precise lookups).
            words_in_para: List[Word] = []
            for char_info in chars:
                char_box = self._to_normalized_bbox(char_info['bbox'], img_width, img_height)
                words_in_para.append(Word(text=char_info['char'], separator="", box=char_box))

            # meikiocr doesn't provide a line-level box, so we must compute it
            # by finding the union of all character boxes in the line.
            min_x = min(c['bbox'][0] for c in chars)
            min_y = min(c['bbox'][1] for c in chars)
            max_x = max(c['bbox'][2] for c in chars)
            max_y = max(c['bbox'][3] for c in chars)
            line_pixel_bbox = [min_x, min_y, max_x, max_y]
            line_box = self._to_normalized_bbox(line_pixel_bbox, img_width, img_height)

            # meikiocr currently only supports horizontal text.
            is_vertical = False

            paragraph = Paragraph(
                full_text=full_text,
                words=words_in_para,
                box=line_box,
                is_vertical=is_vertical
            )
            paragraphs.append(paragraph)

        return paragraphs

    # --- meikiocr pipeline methods (adapted from meiki_ocr.py) ---

    def _preprocess_for_detection(self, image: np.ndarray):
        h_orig, w_orig = image.shape[:2]
        resized = cv2.resize(image, (INPUT_DET_WIDTH, INPUT_DET_HEIGHT), interpolation=cv2.INTER_LINEAR)
        tensor = resized.astype(np.float32) / 255.0
        tensor = np.transpose(tensor, (2, 0, 1))
        tensor = np.expand_dims(tensor, axis=0)
        return tensor, w_orig / INPUT_DET_WIDTH, h_orig / INPUT_DET_HEIGHT

    def _run_detection_inference(self, tensor: np.ndarray):
        inputs = {
            self.det_session.get_inputs()[0].name: tensor,
            self.det_session.get_inputs()[1].name: np.array([[INPUT_DET_WIDTH, INPUT_DET_HEIGHT]], dtype=np.int64)
        }
        return self.det_session.run(None, inputs)

    def _postprocess_detection_results(self, raw_outputs: list, scale_x: float, scale_y: float):
        _, boxes, scores = raw_outputs
        boxes, scores = boxes[0], scores[0]
        text_boxes = []
        for box, score in zip(boxes, scores):
            if score < DET_CONFIDENCE_THRESHOLD: continue
            x1, y1, x2, y2 = box
            text_boxes.append({'bbox': [
                max(0, int(x1 * scale_x)),
                max(0, int(y1 * scale_y)),
                max(0, int(x2 * scale_x)),
                max(0, int(y2 * scale_y))
            ]})
        return text_boxes

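    # worked example for the detection scale factors, assuming a hypothetical 1920x1080 source frame:
    #   scale_x = 1920 / 960 = 2.0, scale_y = 1080 / 544 ≈ 1.985
    #   a model-space box (100, 50, 200, 80) maps back to roughly (200, 99, 400, 158) in source pixels.
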
    def _preprocess_for_recognition(self, image: np.ndarray, text_boxes: list):
        tensors, valid_indices, crop_meta = [], [], []
        for i, tb in enumerate(text_boxes):
            x1, y1, x2, y2 = tb['bbox']
            w, h = x2 - x1, y2 - y1
            if w < h or w <= 0 or h <= 0: continue
            crop = image[y1:y2, x1:x2]
            ch, cw = crop.shape[:2]
            nh, nw = INPUT_REC_HEIGHT, int(round(cw * (INPUT_REC_HEIGHT / ch)))
            if nw > INPUT_REC_WIDTH:
                scale = INPUT_REC_WIDTH / nw
                nw, nh = INPUT_REC_WIDTH, int(round(nh * scale))
            resized = cv2.resize(crop, (nw, nh), interpolation=cv2.INTER_LINEAR)
            pw, ph = INPUT_REC_WIDTH - nw, INPUT_REC_HEIGHT - nh
            padded = np.pad(resized, ((0, ph), (0, pw), (0, 0)), constant_values=0)
            tensor = (padded.astype(np.float32) / 255.0)
            tensor = np.transpose(tensor, (2, 0, 1))
            tensors.append(tensor)
            valid_indices.append(i)
            crop_meta.append({'orig_bbox': [x1, y1, x2, y2], 'effective_w': nw})
        if not tensors: return None, [], []
        return np.stack(tensors, axis=0), valid_indices, crop_meta

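    # worked example for the recognition resize/pad, assuming a hypothetical 300x40 pixel crop:
    #   nh = 32, nw = round(300 * (32 / 40)) = 240, then padded right/bottom with black to 960x32;
    #   effective_w = 240 is kept so character x-coordinates can later be mapped back onto the crop.
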
    def _run_recognition_inference(self, batch_tensor: np.ndarray):
        inputs = {
            "images": batch_tensor,
            "orig_target_sizes": np.array([[INPUT_REC_WIDTH, INPUT_REC_HEIGHT]], dtype=np.int64)
        }
        return self.rec_session.run(None, inputs)

    def _postprocess_recognition_results(self, raw_outputs: list, valid_indices: list, crop_meta: list, num_boxes: int):
        labels_batch, boxes_batch, scores_batch = raw_outputs
        results = [{'text': '', 'chars': []} for _ in range(num_boxes)]
        for i, (labels, boxes, scores) in enumerate(zip(labels_batch, boxes_batch, scores_batch)):
            meta = crop_meta[i]
            gx1, gy1, gx2, gy2 = meta['orig_bbox']
            cw, ch = gx2 - gx1, gy2 - gy1
            ew = meta['effective_w']

            candidates = []
            for lbl, box, scr in zip(labels, boxes, scores):
                if scr < REC_CONFIDENCE_THRESHOLD: continue

                char = chr(lbl)
                rx1, ry1, rx2, ry2 = box
                rx1, rx2 = min(rx1, ew), min(rx2, ew)

                # map: recognition space -> crop space -> global image
                cx1 = (rx1 / ew) * cw
                cx2 = (rx2 / ew) * cw
                cy1 = (ry1 / INPUT_REC_HEIGHT) * ch
                cy2 = (ry2 / INPUT_REC_HEIGHT) * ch

                gx1_char = gx1 + int(cx1)
                gy1_char = gy1 + int(cy1)
                gx2_char = gx1 + int(cx2)
                gy2_char = gy1 + int(cy2)

                candidates.append({
                    'char': char,
                    'bbox': [gx1_char, gy1_char, gx2_char, gy2_char],
                    'conf': float(scr),
                    'x_interval': (gx1_char, gx2_char)
                })

            # sort by confidence (descending) to prepare for deduplication
            candidates.sort(key=lambda c: c['conf'], reverse=True)

            # spatial deduplication on x-axis (non-maximum suppression)
            accepted = []
            for cand in candidates:
                x1_c, x2_c = cand['x_interval']
                width_c = x2_c - x1_c + EPSILON
                keep = True
                for acc in accepted:
                    x1_a, x2_a = acc['x_interval']
                    overlap = max(0, min(x2_c, x2_a) - max(x1_c, x1_a))
                    if overlap / width_c > X_OVERLAP_THRESHOLD:
                        keep = False
                        break
                if keep:
                    accepted.append(cand)

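            # worked example of the overlap test (hypothetical values): candidate interval (100, 140)
            # vs. an already-accepted interval (95, 120) -> overlap 20 / candidate width 40 = 0.5 > 0.3,
            # so the lower-confidence candidate is dropped as a duplicate of the same character.
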
            # sort by x for final reading order
            accepted.sort(key=lambda c: c['x_interval'][0])

            text = ''.join(c['char'] for c in accepted)
            # keep the confidence score in the final output as it can be useful
            final_chars = [{'char': c['char'], 'bbox': c['bbox'], 'conf': c['conf']} for c in accepted]

            results[valid_indices[i]] = {'text': text, 'chars': final_chars}

        return results
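

# a minimal manual-test sketch, assuming a local screenshot saved as "test.png" (a hypothetical
# path); in normal use MeikiPop instantiates the provider through its own provider selection
# rather than by running this module directly.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = MeikiOcrProvider()
    test_image = Image.open("test.png")  # hypothetical sample screenshot
    result = provider.scan(test_image)
    if result is not None:
        for para in result:
            print(para.full_text)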