support images and tables extract from docx (#4619)

This commit is contained in:
Jyong 2024-05-23 18:05:23 +08:00 committed by GitHub
parent 5893ebec55
commit 233c4150d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 163 additions and 23 deletions

View File

@ -16,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
@ -108,7 +107,7 @@ class ExtractProcessor:
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
@ -137,7 +136,7 @@ class ExtractProcessor:
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = WordExtractor(file_path)
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == 'epub':

View File

@ -1,12 +1,20 @@
"""Abstract interface for document loader implementations."""
import datetime
import mimetypes
import os
import tempfile
import uuid
from urllib.parse import urlparse
import requests
from docx import Document as DocxDocument
from flask import current_app
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
class WordExtractor(BaseExtractor):
@ -17,9 +25,12 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(self, file_path: str):
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
from docx import Document as docx_Document
document = docx_Document(self.file_path)
doc_texts = [paragraph.text for paragraph in document.paragraphs]
content = '\n'.join(doc_texts)
content = self.parse_docx(self.file_path, 'storage')
return [Document(
page_content=content,
metadata={"source": self.file_path},
@ -61,3 +67,111 @@ class WordExtractor(BaseExtractor):
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
def _extract_images_from_docx(self, doc, image_folder):
os.makedirs(image_folder, exist_ok=True)
image_count = 0
image_map = {}
for rel in doc.part.rels.values():
if "image" in rel.target_ref:
image_count += 1
image_ext = rel.target_ref.split('.')[-1]
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file_key,
size=0,
extension=image_ext,
mime_type=mime_type,
created_by=self.user_id,
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
used=True,
used_by=self.user_id,
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
)
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
return image_map
def _table_to_markdown(self, table):
markdown = ""
# deal with table headers
header_row = table.rows[0]
headers = [cell.text for cell in header_row.cells]
markdown += "| " + " | ".join(headers) + " |\n"
markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
# deal with table rows
for row in table.rows[1:]:
row_cells = [cell.text for cell in row.cells]
markdown += "| " + " | ".join(row_cells) + " |\n"
return markdown
def _parse_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath('.//a:blip'):
for blip in run.element.xpath('.//a:blip'):
embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
if embed_id:
rel_target = run.part.rels[embed_id].target_ref
if rel_target in image_map:
paragraph_content.append(image_map[rel_target])
if run.text.strip():
paragraph_content.append(run.text.strip())
return ' '.join(paragraph_content) if paragraph_content else ''
def parse_docx(self, docx_path, image_folder):
doc = DocxDocument(docx_path)
os.makedirs(image_folder, exist_ok=True)
content = []
image_map = self._extract_images_from_docx(doc, image_folder)
def parse_paragraph(paragraph):
paragraph_content = []
for run in paragraph.runs:
if run.element.tag.endswith('r'):
drawing_elements = run.element.findall(
'.//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
for drawing in drawing_elements:
blip_elements = drawing.findall(
'.//{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
for blip in blip_elements:
embed_id = blip.get(
'{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())
return ''.join(paragraph_content) if paragraph_content else ''
paragraphs = doc.paragraphs.copy()
tables = doc.tables.copy()
for element in doc.element.body:
if element.tag.endswith('p'): # paragraph
para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para)
if parsed_paragraph:
content.append(parsed_paragraph)
elif element.tag.endswith('tbl'): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table))
return '\n'.join(content)

View File

@ -144,9 +144,9 @@ class DatasetRetrieval:
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
document_context_list.append(segment.get_sign_content())
if show_retrieve_source:
context_list = []
resource_number = 1

View File

@ -99,9 +99,9 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1

View File

@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1

View File

@ -191,9 +191,9 @@ class KnowledgeRetrievalNode(BaseNode):
'title': document.name
}
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
source['content'] = segment.get_sign_content()
context_list.append(source)
resource_number += 1
return context_list

View File

@ -1,8 +1,15 @@
import base64
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError
from flask import current_app
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
@ -414,6 +421,26 @@ class DocumentSegment(db.Model):
DocumentSegment.position == self.position + 1
).first()
def get_sign_content(self):
pattern = r"/files/([a-f0-9\-]+)/image-preview"
text = self.content
match = re.search(pattern, text)
if match:
upload_file_id = match.group(1)
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
replacement = r"\g<0>?{params}".format(params=params)
text = re.sub(pattern, replacement, text)
return text
class AppDatasetJoin(db.Model):
__tablename__ = 'app_dataset_joins'