Source code for airtest.core.cv

#!/usr/bin/env python
# -*- coding: utf-8 -*-

""""Airtest图像识别专用."""

import os
import sys
import time
import types
from six import PY3
from copy import deepcopy

from airtest import aircv
from airtest.aircv import cv2
from airtest.core.helper import G, logwrap
from airtest.core.settings import Settings as ST  # noqa
from airtest.core.error import TargetNotFoundError, InvalidMatchingMethodError
from airtest.utils.transform import TargetPos

from airtest.aircv.template_matching import TemplateMatching
from airtest.aircv.multiscale_template_matching import MultiScaleTemplateMatching,MultiScaleTemplateMatchingPre
from airtest.aircv.keypoint_matching import KAZEMatching, BRISKMatching, AKAZEMatching, ORBMatching
from airtest.aircv.keypoint_matching_contrib import SIFTMatching, SURFMatching, BRIEFMatching

MATCHING_METHODS = {
    "tpl": TemplateMatching,
    "mstpl": MultiScaleTemplateMatchingPre,
    "gmstpl": MultiScaleTemplateMatching,
    "kaze": KAZEMatching,
    "brisk": BRISKMatching,
    "akaze": AKAZEMatching,
    "orb": ORBMatching,
    "sift": SIFTMatching,
    "surf": SURFMatching,
    "brief": BRIEFMatching,
}


[docs]@logwrap def loop_find(query, timeout=ST.FIND_TIMEOUT, threshold=None, interval=0.5, intervalfunc=None): """ Search for image template in the screen until timeout Args: query: image template to be found in screenshot timeout: time interval how long to look for the image template threshold: default is None interval: sleep interval before next attempt to find the image template intervalfunc: function that is executed after unsuccessful attempt to find the image template Raises: TargetNotFoundError: when image template is not found in screenshot Returns: TargetNotFoundError if image template not found, otherwise returns the position where the image template has been found in screenshot """ G.LOGGING.info("Try finding: %s", query) start_time = time.time() while True: screen = G.DEVICE.snapshot(filename=None, quality=ST.SNAPSHOT_QUALITY) if screen is None: G.LOGGING.warning("Screen is None, may be locked") else: if threshold: query.threshold = threshold match_pos = query.match_in(screen) if match_pos: try_log_screen(screen) return match_pos if intervalfunc is not None: intervalfunc() # 超时则raise,未超时则进行下次循环: if (time.time() - start_time) > timeout: try_log_screen(screen) raise TargetNotFoundError('Picture %s not found in screen' % query) else: time.sleep(interval)
[docs]@logwrap def try_log_screen(screen=None, quality=None, max_size=None): """ Save screenshot to file Args: screen: screenshot to be saved quality: The image quality, default is ST.SNAPSHOT_QUALITY max_size: the maximum size of the picture, e.g 1200 Returns: {"screen": filename, "resolution": aircv.get_resolution(screen)} """ if not ST.LOG_DIR or not ST.SAVE_IMAGE: return if not quality: quality = ST.SNAPSHOT_QUALITY if not max_size: max_size = ST.IMAGE_MAXSIZE if screen is None: screen = G.DEVICE.snapshot(quality=quality) filename = "%(time)d.jpg" % {'time': time.time() * 1000} filepath = os.path.join(ST.LOG_DIR, filename) if screen is not None: aircv.imwrite(filepath, screen, quality, max_size=max_size) return {"screen": filename, "resolution": aircv.get_resolution(screen)} return None
[docs]class Template(object): """ picture as touch/swipe/wait/exists target and extra info for cv match filename: pic filename target_pos: ret which pos in the pic record_pos: pos in screen when recording resolution: screen resolution when recording rgb: 识别结果是否使用rgb三通道进行校验. scale_max: 多尺度模板匹配最大范围. scale_step: 多尺度模板匹配搜索步长. """ def __init__(self, filename, threshold=None, target_pos=TargetPos.MID, record_pos=None, resolution=(), rgb=False, scale_max=800, scale_step=0.005): self.filename = filename self._filepath = None self.threshold = threshold or ST.THRESHOLD self.target_pos = target_pos self.record_pos = record_pos self.resolution = resolution self.rgb = rgb self.scale_max = scale_max self.scale_step = scale_step @property def filepath(self): if self._filepath: return self._filepath for dirname in G.BASEDIR: filepath = os.path.join(dirname, self.filename) if os.path.isfile(filepath): self._filepath = filepath return self._filepath return self.filename def __repr__(self): filepath = self.filepath if PY3 else self.filepath.encode(sys.getfilesystemencoding()) return "Template(%s)" % filepath
[docs] def match_in(self, screen): match_result = self._cv_match(screen) G.LOGGING.debug("match result: %s", match_result) if not match_result: return None focus_pos = TargetPos().getXY(match_result, self.target_pos) return focus_pos
[docs] def match_all_in(self, screen): image = self._imread() image = self._resize_image(image, screen, ST.RESIZE_METHOD) return self._find_all_template(image, screen)
@logwrap def _cv_match(self, screen): # in case image file not exist in current directory: ori_image = self._imread() image = self._resize_image(ori_image, screen, ST.RESIZE_METHOD) ret = None for method in ST.CVSTRATEGY: # get function definition and execute: func = MATCHING_METHODS.get(method, None) if func is None: raise InvalidMatchingMethodError("Undefined method in CVSTRATEGY: '%s', try 'kaze'/'brisk'/'akaze'/'orb'/'surf'/'sift'/'brief' instead." % method) else: if method in ["mstpl", "gmstpl"]: ret = self._try_match(func, ori_image, screen, threshold=self.threshold, rgb=self.rgb, record_pos=self.record_pos, resolution=self.resolution, scale_max=self.scale_max, scale_step=self.scale_step) else: ret = self._try_match(func, image, screen, threshold=self.threshold, rgb=self.rgb) if ret: break return ret @staticmethod def _try_match(func, *args, **kwargs): G.LOGGING.debug("try match with %s" % func.__name__) try: ret = func(*args, **kwargs).find_best_result() except aircv.NoModuleError as err: G.LOGGING.warning("'surf'/'sift'/'brief' is in opencv-contrib module. You can use 'tpl'/'kaze'/'brisk'/'akaze'/'orb' in CVSTRATEGY, or reinstall opencv with the contrib module.") return None except aircv.BaseError as err: G.LOGGING.debug(repr(err)) return None else: return ret def _imread(self): return aircv.imread(self.filepath) def _find_all_template(self, image, screen): return TemplateMatching(image, screen, threshold=self.threshold, rgb=self.rgb).find_all_results() def _find_keypoint_result_in_predict_area(self, func, image, screen): if not self.record_pos: return None # calc predict area in screen image_wh, screen_resolution = aircv.get_resolution(image), aircv.get_resolution(screen) xmin, ymin, xmax, ymax = Predictor.get_predict_area(self.record_pos, image_wh, self.resolution, screen_resolution) # crop predict image from screen predict_area = aircv.crop_image(screen, (xmin, ymin, xmax, ymax)) if not predict_area.any(): return None # keypoint matching in predicted area: ret_in_area = func(image, predict_area, threshold=self.threshold, rgb=self.rgb) # calc cv ret if found if not ret_in_area: return None ret = deepcopy(ret_in_area) if "rectangle" in ret: for idx, item in enumerate(ret["rectangle"]): ret["rectangle"][idx] = (item[0] + xmin, item[1] + ymin) ret["result"] = (ret_in_area["result"][0] + xmin, ret_in_area["result"][1] + ymin) return ret def _resize_image(self, image, screen, resize_method): """模板匹配中,将输入的截图适配成 等待模板匹配的截图.""" # 未记录录制分辨率,跳过 if not self.resolution: return image screen_resolution = aircv.get_resolution(screen) # 如果分辨率一致,则不需要进行im_search的适配: if tuple(self.resolution) == tuple(screen_resolution) or resize_method is None: return image if isinstance(resize_method, types.MethodType): resize_method = resize_method.__func__ # 分辨率不一致则进行适配,默认使用cocos_min_strategy: h, w = image.shape[:2] w_re, h_re = resize_method(w, h, self.resolution, screen_resolution) # 确保w_re和h_re > 0, 至少有1个像素: w_re, h_re = max(1, w_re), max(1, h_re) # 调试代码: 输出调试信息. G.LOGGING.debug("resize: (%s, %s)->(%s, %s), resolution: %s=>%s" % ( w, h, w_re, h_re, self.resolution, screen_resolution)) # 进行图片缩放: image = cv2.resize(image, (w_re, h_re)) return image
[docs]class Predictor(object): """ this class predicts the press_point and the area to search im_search. """ DEVIATION = 100
[docs] @staticmethod def count_record_pos(pos, resolution): """计算坐标对应的中点偏移值相对于分辨率的百分比.""" _w, _h = resolution # 都按宽度缩放,针对G18的实验结论 delta_x = (pos[0] - _w * 0.5) / _w delta_y = (pos[1] - _h * 0.5) / _w delta_x = round(delta_x, 3) delta_y = round(delta_y, 3) return delta_x, delta_y
[docs] @classmethod def get_predict_point(cls, record_pos, screen_resolution): """预测缩放后的点击位置点.""" delta_x, delta_y = record_pos _w, _h = screen_resolution target_x = delta_x * _w + _w * 0.5 target_y = delta_y * _w + _h * 0.5 return target_x, target_y
[docs] @classmethod def get_predict_area(cls, record_pos, image_wh, image_resolution=(), screen_resolution=()): """Get predicted area in screen.""" x, y = cls.get_predict_point(record_pos, screen_resolution) # The prediction area should depend on the image size: if image_resolution: predict_x_radius = int(image_wh[0] * screen_resolution[0] / (2 * image_resolution[0])) + cls.DEVIATION predict_y_radius = int(image_wh[1] * screen_resolution[1] / (2 * image_resolution[1])) + cls.DEVIATION else: predict_x_radius, predict_y_radius = int(image_wh[0] / 2) + cls.DEVIATION, int(image_wh[1] / 2) + cls.DEVIATION area = (x - predict_x_radius, y - predict_y_radius, x + predict_x_radius, y + predict_y_radius) return area