Skip to content

Commit

Permalink
added sliding window for large image inference (#12152)
Browse files Browse the repository at this point in the history
added sliding window for large image inference
  • Loading branch information
aspaul20 committed May 24, 2024
1 parent 28f7a96 commit 965f569
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 3 deletions.
16 changes: 16 additions & 0 deletions doc/doc_en/slice_en.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Slice Operator
If you have a very large image/document that you would like to run PaddleOCR (detection and recognition) on, you can use the slice operation as follows:

`ocr_inst = PaddleOCR(**ocr_settings)`
`results = ocr_inst.ocr(img, det=True,rec=True, slice=slice, cls=False,bin=False,inv=False,alpha_color=False)`

where
`slice = {'horizontal_stride': h_stride, 'vertical_stride':v_stride, 'merge_x_thres':x_thres, 'merge_y_thres': y_thres}`

Here, `h_stride`, `v_stride`, `x_thres`, and `y_thres` are user-configurable values that must be set manually. The `slice` operator runs a sliding window across the large input image, cuts it into slices, and runs the OCR algorithms on each slice.

The fragmented slice-level results are then merged together to produce image-level detection and recognition results. The horizontal and vertical strides cannot be lower than a certain limit, because values that are too low would create so many slices that computing results for each of them would be very expensive. As an example, the recommended values for an image with dimensions 6616x14886 are as follows.

`slice = {'horizontal_stride': 300, 'vertical_stride':500, 'merge_x_thres':50, 'merge_y_thres': 35}`

All slice-level detections whose bounding boxes are within `merge_x_thres` and `merge_y_thres` of each other will be merged together.
4 changes: 3 additions & 1 deletion paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def ocr(
bin=False,
inv=False,
alpha_color=(255, 255, 255),
slice={},
):
"""
OCR with PaddleOCR
Expand All @@ -691,6 +692,7 @@ def ocr(
bin: binarize image to black and white. Default is False.
inv: invert image colors. Default is False.
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
slice: use sliding window inference for large images, det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
Expand Down Expand Up @@ -723,7 +725,7 @@ def preprocess_image(_image):
ocr_res = []
for idx, img in enumerate(imgs):
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls)
dt_boxes, rec_res, _ = self.__call__(img, cls, slice)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
Expand Down
35 changes: 33 additions & 2 deletions tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
draw_ocr_box_txt,
get_rotate_crop_image,
get_minarea_rect_crop,
slice_generator,
merge_fragmented,
)

logger = get_logger()
Expand Down Expand Up @@ -71,7 +73,7 @@ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num

def __call__(self, img, cls=True):
def __call__(self, img, cls=True, slice={}):
time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}

if img is None:
Expand All @@ -80,7 +82,32 @@ def __call__(self, img, cls=True):

start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
if slice:
slice_gen = slice_generator(
img,
horizontal_stride=slice["horizontal_stride"],
vertical_stride=slice["vertical_stride"],
)
elapsed = []
dt_slice_boxes = []
for slice_crop, v_start, h_start in slice_gen:
dt_boxes, elapse = self.text_detector(slice_crop)
if dt_boxes.size:
dt_boxes[:, :, 0] += h_start
dt_boxes[:, :, 1] += v_start
dt_slice_boxes.append(dt_boxes)
elapsed.append(elapse)
dt_boxes = np.concatenate(dt_slice_boxes)

dt_boxes = merge_fragmented(
boxes=dt_boxes,
x_threshold=slice["merge_x_thres"],
y_threshold=slice["merge_y_thres"],
)
elapse = sum(elapsed)
else:
dt_boxes, elapse = self.text_detector(img)

time_dict["det"] = elapse

if dt_boxes is None:
Expand Down Expand Up @@ -109,6 +136,10 @@ def __call__(self, img, cls=True):
logger.debug(
"cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
)
if len(img_crop_list) > 1000:
logger.debug(
f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
)

rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict["rec"] = elapse
Expand Down
98 changes: 98 additions & 0 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,104 @@ def get_minarea_rect_crop(img, points):
return crop_img


def _recommended_min_stride(extent, maximum_slices):
    """Return the smallest stride that tiles *extent* in fewer than *maximum_slices* pieces."""
    # ceil(extent / stride) < maximum_slices  <=>  stride >= extent / (maximum_slices - 1).
    # For maximum_slices <= 1 no stride can satisfy the limit; the divisor is
    # clamped to 1 only to avoid dividing by zero.
    usable_slices = max(1, maximum_slices - 1)
    return max(1, -(-extent // usable_slices))


def slice_generator(image, horizontal_stride, vertical_stride, maximum_slices=500):
    """Yield tiles of *image* in row-major order.

    Tiles do not overlap: each one is at most ``vertical_stride`` rows by
    ``horizontal_stride`` columns, and consecutive tiles start exactly one
    stride apart. Tiles on the bottom/right border are clipped to the image.

    Because this is a generator, the validation below runs when the first
    tile is requested, not when the generator object is created.

    Args:
        image: ``np.ndarray`` (or anything ``np.array`` accepts) whose first
            two axes are rows and columns.
        horizontal_stride: tile width in pixels.
        vertical_stride: tile height in pixels.
        maximum_slices: exclusive upper bound on the number of tiles along
            each axis.

    Yields:
        ``(tile, v_start, h_start)`` where ``tile`` is a basic NumPy slice of
        the image and ``v_start`` / ``h_start`` are the row / column offsets
        of its top-left corner in the full image.

    Raises:
        AssertionError: if a stride produces no tiles (e.g. a negative stride
            or an empty image), or ``maximum_slices`` or more tiles along an
            axis.
        ZeroDivisionError: if a stride is 0.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    image_h, image_w = image.shape[:2]
    # Ceiling division: a partial tile at the border still counts as a tile.
    vertical_num_slices = (image_h + vertical_stride - 1) // vertical_stride
    horizontal_num_slices = (image_w + horizontal_stride - 1) // horizontal_stride

    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`; AssertionError is kept so existing callers are unaffected.
    if not vertical_num_slices > 0:
        raise AssertionError(
            f"Invalid number ({vertical_num_slices}) of vertical slices"
        )

    if not horizontal_num_slices > 0:
        raise AssertionError(
            f"Invalid number ({horizontal_num_slices}) of horizontal slices"
        )

    if vertical_num_slices >= maximum_slices:
        recommended_vertical_stride = _recommended_min_stride(image_h, maximum_slices)
        raise AssertionError(
            f"Too computationally expensive with {vertical_num_slices} slices, try a higher vertical stride (recommended minimum: {recommended_vertical_stride})"
        )

    if horizontal_num_slices >= maximum_slices:
        recommended_horizontal_stride = _recommended_min_stride(
            image_w, maximum_slices
        )
        raise AssertionError(
            f"Too computationally expensive with {horizontal_num_slices} slices, try a higher horizontal stride (recommended minimum: {recommended_horizontal_stride})"
        )

    for v_slice_idx in range(vertical_num_slices):
        v_start = v_slice_idx * vertical_stride
        v_end = min(v_start + vertical_stride, image_h)
        vertical_slice = image[v_start:v_end, :]
        for h_slice_idx in range(horizontal_num_slices):
            h_start = h_slice_idx * horizontal_stride
            h_end = min(h_start + horizontal_stride, image_w)

            yield (vertical_slice[:, h_start:h_end], v_start, h_start)


def calculate_box_extents(box):
    """Return the axis-aligned extents of *box* as ``(min_x, max_x, min_y, max_y)``.

    Args:
        box: sequence of ``[x, y]`` corner points.

    Returns:
        Tuple ``(min_x, max_x, min_y, max_y)`` taken over every corner, so the
        result is also correct for skewed quadrilaterals and does not depend
        on the order in which the corners are listed.
    """
    xs = [point[0] for point in box]
    ys = [point[1] for point in box]
    return min(xs), max(xs), min(ys), max(ys)


def merge_boxes(box1, box2, x_threshold, y_threshold):
    """Fuse two boxes into one rectangle when they are close enough.

    The boxes are fused only if all of the following hold:
      * their minimum y values differ by at most ``y_threshold``;
      * their maximum y values differ by at most ``y_threshold``;
      * the maximum x of ``box1`` and the minimum x of ``box2`` differ by at
        most ``x_threshold``.

    The x test is directional: it compares the far x edge of ``box1`` with
    the near x edge of ``box2``, so argument order matters.

    Returns:
        A list of four ``[x, y]`` corners covering both boxes, in the order
        (min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y),
        or ``None`` when the boxes do not qualify.
    """
    x_lo1, x_hi1, y_lo1, y_hi1 = calculate_box_extents(box1)
    x_lo2, x_hi2, y_lo2, y_hi2 = calculate_box_extents(box2)

    mergeable = (
        abs(y_lo1 - y_lo2) <= y_threshold
        and abs(y_hi1 - y_hi2) <= y_threshold
        and abs(x_hi1 - x_lo2) <= x_threshold
    )
    if not mergeable:
        return None

    x_lo = min(x_lo1, x_lo2)
    x_hi = max(x_hi1, x_hi2)
    y_lo = min(y_lo1, y_lo2)
    y_hi = max(y_hi1, y_hi2)
    return [[x_lo, y_lo], [x_hi, y_lo], [x_hi, y_hi], [x_lo, y_hi]]


def merge_fragmented(boxes, x_threshold=10, y_threshold=10):
    """Repeatedly merge neighbouring boxes until a pass merges nothing.

    Each pass walks the boxes in order. Every box not yet absorbed becomes a
    seed, and each later unabsorbed box is offered to ``merge_boxes`` together
    with the seed's current (possibly already grown) rectangle. Passes repeat
    on the reduced list until the number of boxes stops shrinking.

    Args:
        boxes: indexable, sized sequence of boxes, each a sequence of
            ``[x, y]`` points.
        x_threshold: forwarded to ``merge_boxes``.
        y_threshold: forwarded to ``merge_boxes``.

    Returns:
        ``np.ndarray`` built from the final list of boxes.
    """
    current = boxes
    while True:
        fused = []
        absorbed = set()

        for i, seed in enumerate(current):
            if i in absorbed:
                continue

            running = [point[:] for point in seed]

            for j in range(i + 1, len(current)):
                if j in absorbed:
                    continue
                candidate = merge_boxes(
                    running, current[j], x_threshold=x_threshold, y_threshold=y_threshold
                )
                if not candidate:
                    continue
                running = candidate
                absorbed.add(j)

            fused.append(running)

        # No box was absorbed in this pass: the result is stable.
        if len(fused) == len(current):
            return np.array(fused)
        current = fused


def check_gpu(use_gpu):
if use_gpu and (
not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
Expand Down

0 comments on commit 965f569

Please sign in to comment.