From 83e2570d3c8da9920d66a00c4bdf5650fe1b3336 Mon Sep 17 00:00:00 2001 From: Akshay Nair Date: Wed, 25 Dec 2024 22:33:03 +0530 Subject: Parallel ocr evaluation for sections of screen + many refactorings --- chelleport.cabal | 5 +- cpp/image.cpp | 41 ++++++++++++++++ cpp/libchelleport.cpp | 122 ++++++++++++++++-------------------------------- cpp/recognizer.cpp | 65 ++++++++++++++++++++++++++ include/image.h | 19 ++++++++ include/libchelleport.h | 52 ++++++--------------- include/recognizer.h | 50 ++++++++++++++++++++ src/Chelleport/OCR.hs | 1 - 8 files changed, 234 insertions(+), 121 deletions(-) create mode 100644 cpp/image.cpp create mode 100644 cpp/recognizer.cpp create mode 100644 include/image.h create mode 100644 include/recognizer.h diff --git a/chelleport.cabal b/chelleport.cabal index 2d03b63..bf269ba 100644 --- a/chelleport.cabal +++ b/chelleport.cabal @@ -50,7 +50,10 @@ common warnings common extension extra-libraries: stdc++ Xtst X11 tesseract leptonica include-dirs: include - c-sources: cpp/libchelleport.cpp + c-sources: + cpp/libchelleport.cpp + cpp/recognizer.cpp + cpp/image.cpp cxx-options: -O3 -ffast-math -march=native extra-source-files: cpp/*.cpp diff --git a/cpp/image.cpp b/cpp/image.cpp new file mode 100644 index 0000000..63f0a08 --- /dev/null +++ b/cpp/image.cpp @@ -0,0 +1,41 @@ +#include +#include +#include + +#include "../include/image.h" + +namespace image { +void preprocessImage(Pix **image) { + Pix *temp; + + // Scale + if (scaleFactor != 1) { + INLINE_IMAGE_PROC(pixScale(*image, scaleFactor, scaleFactor)); + } + + // Grayscale + if (pixGetDepth(*image) > 8) { + INLINE_IMAGE_PROC(pixConvertRGBToGray( + *image, grayscaleWeightRed, grayscaleWeightGreen, grayscaleWeightBlue)); + } + + // Contrast + pixContrastTRC(*image, *image, contrast); + + // Sharpness + // INLINE_IMAGE_PROC(pixUnsharpMaskingGrayFast(*image, 1, sharpness, 1)); + INLINE_IMAGE_PROC(pixUnsharpMasking(*image, 1, sharpness)); +} + +Pix *loadImage(const char *imagePath) { + Pix *image = pixRead(imagePath); + if (!image) { + std::cerr << "Could not load image " << imagePath << std::endl; + return nullptr; + } + + preprocessImage(&image); + + return image; +} +} // namespace image diff --git a/cpp/libchelleport.cpp b/cpp/libchelleport.cpp index 923df6c..67abb06 100644 --- a/cpp/libchelleport.cpp +++ b/cpp/libchelleport.cpp @@ -1,19 +1,20 @@ -#include -#include -#include -#include -#include #include +#include +#include #include +#include #include +#include "../include/image.h" #include "../include/libchelleport.h" +#include "../include/recognizer.h" +extern "C" { OCRMatch *findWordCoordinates(const char *image_path, int *size) { std::vector matches; - MEASURE("OCR", { matches = extractTextCoordinates(image_path); }); + MEASURE("OCR", { matches = extractTextMatches(image_path); }); - std::cout << "Word count: " << matches.size() << std::endl; + std::cout << "Match count: " << matches.size() << std::endl; static OCRMatch *ptr = new OCRMatch[matches.size()]; std::copy(matches.begin(), matches.end(), ptr); @@ -21,101 +22,60 @@ OCRMatch *findWordCoordinates(const char *image_path, int *size) { *size = matches.size(); return ptr; } +} -std::vector extractTextCoordinates(const char *imagePath) { +std::vector extractTextMatches(const char *imagePath) { std::vector results; - auto tesseract = initializeTesseract(); - if (tesseract == nullptr) - return results; - - Pix *image = loadImage(imagePath); + Pix *image = image::loadImage(imagePath); if (image == nullptr) return results; // printf("imagePath: %s\n", imagePath); // pixWrite(imagePath, image, IFF_JFIF_JPEG); - tesseract->SetImage(image); - tesseract->Recognize(0); - - tesseract::ResultIterator *iterator = tesseract->GetIterator(); - auto level = RESULT_ITER_MODE; - - if (iterator != 0) { - do { - if (iterator->Confidence(level) > CONFIDENCE_THRESHOLD) { - const char *word = iterator->GetUTF8Text(level); - - if (word != nullptr && strlen(word) >= MIN_CHARACTER_COUNT) { - int x1, y1, x2, y2; - iterator->BoundingBox(level, &x1, &y1, &x2, &y2); - OCRMatch match({(int)(x1 / scaleFactor), (int)(y1 / scaleFactor), - (int)(x2 / scaleFactor), (int)(y2 / scaleFactor), - word}); - results.push_back(match); - } - } - } while (iterator->Next(level)); - } + int width = pixGetWidth(image); + int height = pixGetHeight(image); - delete iterator; - tesseract->End(); - delete tesseract; - pixDestroy(&image); + std::vector> recognizers; + recognizers.push_back( + std::make_unique(0, 0, width / 2, height / 2)); - return results; -} + recognizers.push_back( + std::make_unique(width / 2, 0, width / 2, height / 2)); -inline tesseract::TessBaseAPI *initializeTesseract() { - auto *tesseract = new tesseract::TessBaseAPI(); - tesseract->SetPageSegMode(tesseract::PSM_AUTO); + recognizers.push_back( + std::make_unique(0, height / 2, width / 2, height / 2)); - if (tesseract->Init(nullptr, "eng", tesseract::OEM_LSTM_ONLY)) { - std::cerr << "Could not initialize tesseract." << std::endl; - return nullptr; - } + recognizers.push_back(std::make_unique(width / 2, height / 2, + width / 2, height / 2)); - return tesseract; + return runRecognizers(recognizers, image); } -inline Pix *loadImage(const char *imagePath) { - Pix *image = pixRead(imagePath); - if (!image) { - std::cerr << "Could not load image " << imagePath << std::endl; - return nullptr; - } - - preprocessImage(&image); - - return image; -} +std::vector +runRecognizers(std::vector> &recognizers, + Pix *image) { + std::vector results; + std::shared_ptr sharedImage(image, [](Pix *p) { pixDestroy(&p); }); -void preprocessImage(Pix **image) { - Pix *temp; + std::vector workers; + workers.reserve(recognizers.size()); - // Scale - if (scaleFactor != 1) { - INLINE_IMAGE_PROC(pixScale(*image, scaleFactor, scaleFactor)); + for (auto &ext : recognizers) { + workers.push_back(std::thread( + [&ext, &sharedImage]() { ext->recognize(sharedImage.get()); })); } - // Grayscale - if (pixGetDepth(*image) > 8) { - INLINE_IMAGE_PROC(pixConvertRGBToGray( - *image, grayscaleWeightRed, grayscaleWeightGreen, grayscaleWeightBlue)); + for (std::thread &t : workers) { + if (t.joinable()) + t.join(); } - // Contrast - pixContrastTRC(*image, *image, contrast); - - // Sharpness - // INLINE_IMAGE_PROC(pixUnsharpMaskingGrayFast(*image, 1, sharpness, 1)); - INLINE_IMAGE_PROC(pixUnsharpMasking(*image, 1, sharpness)); -} + for (auto &ext : recognizers) { + for (auto &match : ext->getResults()) + results.push_back(match); + } -void printMatch(const OCRMatch &match) { - std::cout << "Text: " << match.text << "; Position: (" << match.startX << "," - << match.startY << ") -> (" << match.endX << "," << match.endY - << ")" << std::endl - << std::endl; + return results; } diff --git a/cpp/recognizer.cpp b/cpp/recognizer.cpp new file mode 100644 index 0000000..6f19322 --- /dev/null +++ b/cpp/recognizer.cpp @@ -0,0 +1,65 @@ +#include +#include + +#include "../include/recognizer.h" + +void Recognizer::initializeTesseract() { + tesseract = new tesseract::TessBaseAPI(); + tesseract->SetPageSegMode(tesseract::PSM_AUTO); + + if (tesseract->Init(nullptr, "eng", tesseract::OEM_LSTM_ONLY)) + fail("Could not initialize tesseract."); +} + +void Recognizer::recognize(Pix *image) { + if (failed) + return; + + tesseract->SetImage(image); + tesseract->SetRectangle(x, y, width, height); + if (tesseract->Recognize(0) != 0) + fail("tesseract recognize failed"); +} + +std::vector Recognizer::getResults() { + std::vector results; + + if (failed) + return results; + + tesseract::ResultIterator *iterator = tesseract->GetIterator(); + if (iterator == 0) + return results; + + do { + auto match = fetchMatch(iterator); + if (match != nullptr) + results.push_back(*match); + } while (iterator->Next(ITER_LEVEL)); + + delete iterator; + + return results; +} + +OCRMatch *Recognizer::fetchMatch(tesseract::ResultIterator *iterator) { + if (iterator->Confidence(ITER_LEVEL) < CONFIDENCE_THRESHOLD) + return nullptr; + + const char *word = iterator->GetUTF8Text(ITER_LEVEL); + + if (word == nullptr || strlen(word) < MIN_CHARACTER_COUNT) + return nullptr; + + int x1, y1, x2, y2; + iterator->BoundingBox(ITER_LEVEL, &x1, &y1, &x2, &y2); + + return new OCRMatch( + {(int)(x1 / image::scaleFactor), (int)(y1 / image::scaleFactor), + (int)(x2 / image::scaleFactor), (int)(y2 / image::scaleFactor), word}); +} + +void Recognizer::fail(const char *msg) { + this->failed = true; + std::cerr << msg << std::endl; +} diff --git a/include/image.h b/include/image.h new file mode 100644 index 0000000..ecbc255 --- /dev/null +++ b/include/image.h @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace image { +// Preprocessing configuration +static const float contrast = 0.3; +static const float sharpness = 0.7; +static const float scaleFactor = 1; +static const float grayscaleWeightRed = 0.114; +static const float grayscaleWeightGreen = 0.587; +static const float grayscaleWeightBlue = 0.299; + +Pix *loadImage(const char *imagePath); +} // namespace image + +#define INLINE_IMAGE_PROC(process) \ + temp = process; \ + pixDestroy(image); \ + *image = temp; diff --git a/include/libchelleport.h b/include/libchelleport.h index e6a074d..b69466e 100644 --- a/include/libchelleport.h +++ b/include/libchelleport.h @@ -1,46 +1,12 @@ +#pragma once #include +#include #include +#include #include #include -// NOTE: Remember to update size and alignment in ocr hs module on change -struct OCRMatch { - int startX, startY; - int endX, endY; - const char *text; -}; - -// OCR configuration -#define CONFIDENCE_THRESHOLD 25. -#define MIN_CHARACTER_COUNT 3 -const tesseract::PageIteratorLevel RESULT_ITER_MODE = tesseract::RIL_WORD; - -// Preprocessing configuration -const float contrast = 0.3; -const float sharpness = 0.7; -const float scaleFactor = 1; -const float grayscaleWeightRed = 0.114; -const float grayscaleWeightGreen = 0.587; -const float grayscaleWeightBlue = 0.299; - -extern "C" { -OCRMatch *findWordCoordinates(const char *image_path, /* returns */ int *size); -} - -tesseract::TessBaseAPI *initializeTesseract(); - -Pix *loadImage(const char *imagePath); - -std::vector extractTextCoordinates(const char *imagePath); - -void printMatch(const OCRMatch &match); - -void preprocessImage(Pix **image); - -#define INLINE_IMAGE_PROC(process) \ - temp = process; \ - pixDestroy(image); \ - *image = temp; +#include "./recognizer.h" #define MEASURE(label, stmts) \ auto start = std::chrono::high_resolution_clock::now(); \ @@ -49,3 +15,13 @@ void preprocessImage(Pix **image); auto duration = \ std::chrono::duration_cast(end - start); \ std::cout << label << ": " << duration.count() / 1000.0 << " ms" << std::endl; + +extern "C" { +OCRMatch *findWordCoordinates(const char *image_path, /* returns */ int *size); +} + +std::vector extractTextMatches(const char *imagePath); + +std::vector +runRecognizers(std::vector> &recognizers, + Pix *image); diff --git a/include/recognizer.h b/include/recognizer.h new file mode 100644 index 0000000..57747bb --- /dev/null +++ b/include/recognizer.h @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include + +#include "./image.h" + +// OCR configuration +#define CONFIDENCE_THRESHOLD 25. +#define MIN_CHARACTER_COUNT 3 +const tesseract::PageIteratorLevel ITER_LEVEL = tesseract::RIL_WORD; + +// NOTE: Remember to update size and alignment in ocr hs module on change +struct OCRMatch { + int startX, startY; + int endX, endY; + const char *text; +}; + +class Recognizer { + tesseract::TessBaseAPI *tesseract; + int x, y, width, height; + bool failed = false; + +public: + Recognizer(int x, int y, int width, int height) + : x(x), y(y), width(width), height(height) { + initializeTesseract(); + } + + ~Recognizer() { tesseract->End(); } + + void fail(const char *msg); + + void recognize(Pix *image); + + OCRMatch *fetchMatch(tesseract::ResultIterator *iterator); + + std::vector getResults(); + +private: + void initializeTesseract(); +}; + +inline void printMatch(const OCRMatch &match) { + std::cout << "Text: " << match.text << "; Position: (" << match.startX << "," + << match.startY << ") -> (" << match.endX << "," << match.endY + << ")" << std::endl + << std::endl; +} diff --git a/src/Chelleport/OCR.hs b/src/Chelleport/OCR.hs index 5ee331c..3cd83ae 100644 --- a/src/Chelleport/OCR.hs +++ b/src/Chelleport/OCR.hs @@ -1,7 +1,6 @@ module Chelleport.OCR (MonadOCR (..)) where import Chelleport.Types -import Chelleport.Utils (benchmark) import Control.Concurrent (threadDelay) import Control.Monad.IO.Class (MonadIO (liftIO)) import Control.Monad.RWS (MonadReader (ask)) -- cgit v1.3.1