support images and tables extract from docx (#4619)

Jyong 2024-05-23 18:05:23 +08:00 committed by GitHub
parent 5893ebec55
commit 233c4150d1
10 changed files with 163 additions and 23 deletions

View File

@@ -428,7 +428,7 @@ class IndexingRunner:
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
         else:
@@ -436,7 +436,7 @@ class IndexingRunner:
             character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
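The separator change above swaps `"."` for `". "` in the fallback list. Below is a minimal, self-contained sketch of the effect; the helper and sample text are illustrative only and are not Dify's splitter, which keeps recursing when chunks are still too large.

```python
# Illustration only: a first-match split over the separator priority list.
def first_separator_split(text: str, separators: list[str]) -> list[str]:
    for sep in separators:
        if sep and sep in text:
            return text.split(sep)
    return [text]

text = "Pi is roughly 3.14159. It is irrational. See section 2.3 for details."

# With "." the fallback cuts inside numbers such as 3.14159 and 2.3:
print(first_separator_split(text, ["\n\n", "。", ".", " ", ""]))
# ['Pi is roughly 3', '14159', ' It is irrational', ' See section 2', '3 for details', '']

# With ". " the splits land only on sentence boundaries:
print(first_separator_split(text, ["\n\n", "。", ". ", " ", ""]))
# ['Pi is roughly 3.14159', 'It is irrational', 'See section 2.3 for details.']
```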

View File

@@ -16,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
 from core.rag.extractor.notion_extractor import NotionExtractor
 from core.rag.extractor.pdf_extractor import PdfExtractor
 from core.rag.extractor.text_extractor import TextExtractor
-from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
 from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
 from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
 from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
@@ -108,7 +107,7 @@ class ExtractProcessor:
                 elif file_extension in ['.htm', '.html']:
                     extractor = HtmlExtractor(file_path)
                 elif file_extension in ['.docx']:
-                    extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
+                    extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == '.csv':
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
                 elif file_extension == '.msg':
@@ -137,7 +136,7 @@ class ExtractProcessor:
                 elif file_extension in ['.htm', '.html']:
                     extractor = HtmlExtractor(file_path)
                 elif file_extension in ['.docx']:
-                    extractor = WordExtractor(file_path)
+                    extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == '.csv':
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
                 elif file_extension == 'epub':

View File

@@ -1,12 +1,20 @@
 """Abstract interface for document loader implementations."""
+import datetime
+import mimetypes
 import os
 import tempfile
+import uuid
 from urllib.parse import urlparse

 import requests
+from docx import Document as DocxDocument
+from flask import current_app

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import UploadFile


 class WordExtractor(BaseExtractor):
@@ -17,9 +25,12 @@ class WordExtractor(BaseExtractor):
         file_path: Path to the file to load.
     """

-    def __init__(self, file_path: str):
+    def __init__(self, file_path: str, tenant_id: str, user_id: str):
         """Initialize with file path."""
         self.file_path = file_path
+        self.tenant_id = tenant_id
+        self.user_id = user_id

         if "~" in self.file_path:
             self.file_path = os.path.expanduser(self.file_path)
@@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor):
     def extract(self) -> list[Document]:
         """Load given path as single page."""
-        from docx import Document as docx_Document
-
-        document = docx_Document(self.file_path)
-
-        doc_texts = [paragraph.text for paragraph in document.paragraphs]
-        content = '\n'.join(doc_texts)
+        content = self.parse_docx(self.file_path, 'storage')

         return [Document(
             page_content=content,
             metadata={"source": self.file_path},
@@ -61,3 +67,111 @@ class WordExtractor(BaseExtractor):
         """Check if the url is valid."""
         parsed = urlparse(url)
         return bool(parsed.netloc) and bool(parsed.scheme)
+
+    def _extract_images_from_docx(self, doc, image_folder):
+        os.makedirs(image_folder, exist_ok=True)
+        image_count = 0
+        image_map = {}
+
+        for rel in doc.part.rels.values():
+            if "image" in rel.target_ref:
+                image_count += 1
+                image_ext = rel.target_ref.split('.')[-1]
+                # use uuid as file name
+                file_uuid = str(uuid.uuid4())
+                file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
+                mime_type, _ = mimetypes.guess_type(file_key)
+
+                storage.save(file_key, rel.target_part.blob)
+
+                # save file to db
+                config = current_app.config
+                upload_file = UploadFile(
+                    tenant_id=self.tenant_id,
+                    storage_type=config['STORAGE_TYPE'],
+                    key=file_key,
+                    name=file_key,
+                    size=0,
+                    extension=image_ext,
+                    mime_type=mime_type,
+                    created_by=self.user_id,
+                    created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                    used=True,
+                    used_by=self.user_id,
+                    used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                )
+
+                db.session.add(upload_file)
+                db.session.commit()
+                image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"

+        return image_map
+
+    def _table_to_markdown(self, table):
+        markdown = ""
+        # deal with table headers
+        header_row = table.rows[0]
+        headers = [cell.text for cell in header_row.cells]
+        markdown += "| " + " | ".join(headers) + " |\n"
+        markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
+        # deal with table rows
+        for row in table.rows[1:]:
+            row_cells = [cell.text for cell in row.cells]
+            markdown += "| " + " | ".join(row_cells) + " |\n"
+
+        return markdown
+
+    def _parse_paragraph(self, paragraph, image_map):
+        paragraph_content = []
+        for run in paragraph.runs:
+            if run.element.xpath('.//a:blip'):
+                for blip in run.element.xpath('.//a:blip'):
+                    embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+                    if embed_id:
+                        rel_target = run.part.rels[embed_id].target_ref
+                        if rel_target in image_map:
+                            paragraph_content.append(image_map[rel_target])
+            if run.text.strip():
+                paragraph_content.append(run.text.strip())
+        return ' '.join(paragraph_content) if paragraph_content else ''
+
+    def parse_docx(self, docx_path, image_folder):
+        doc = DocxDocument(docx_path)
+        os.makedirs(image_folder, exist_ok=True)
+
+        content = []
+
+        image_map = self._extract_images_from_docx(doc, image_folder)
+
+        def parse_paragraph(paragraph):
+            paragraph_content = []
+            for run in paragraph.runs:
+                if run.element.tag.endswith('r'):
+                    drawing_elements = run.element.findall(
+                        './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
+                    for drawing in drawing_elements:
+                        blip_elements = drawing.findall(
+                            './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
+                        for blip in blip_elements:
+                            embed_id = blip.get(
+                                '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+                            if embed_id:
+                                image_part = doc.part.related_parts.get(embed_id)
+                                if image_part in image_map:
+                                    paragraph_content.append(image_map[image_part])
+                if run.text.strip():
+                    paragraph_content.append(run.text.strip())
+            return ''.join(paragraph_content) if paragraph_content else ''
+
+        paragraphs = doc.paragraphs.copy()
+        tables = doc.tables.copy()
+        for element in doc.element.body:
+            if element.tag.endswith('p'):  # paragraph
+                para = paragraphs.pop(0)
+                parsed_paragraph = parse_paragraph(para)
+                if parsed_paragraph:
+                    content.append(parsed_paragraph)
+            elif element.tag.endswith('tbl'):  # table
+                table = tables.pop(0)
+                content.append(self._table_to_markdown(table))
+        return '\n'.join(content)
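A usage sketch of the reworked extractor, assuming a configured Flask app context, database session, and storage backend as in Dify; the file path and IDs below are placeholders.

```python
from core.rag.extractor.word_extractor import WordExtractor

extractor = WordExtractor(
    "/tmp/handbook.docx",     # placeholder path to the uploaded .docx
    tenant_id="tenant-uuid",  # placeholder: workspace that owns the extracted images
    user_id="user-uuid",      # placeholder: recorded as creator of the UploadFile rows
)
documents = extractor.extract()

# page_content now interleaves paragraph text, markdown tables, and
# ![image](<CONSOLE_API_URL>/files/<upload_file_id>/image-preview) links
# for each embedded picture persisted to storage.
print(documents[0].page_content)
```

Each embedded image is written to object storage under `image_files/<tenant_id>/` and registered as an `UploadFile` row, which is what makes the generated preview links resolvable through the console API.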

View File

@@ -57,7 +57,7 @@ class BaseIndexProcessor(ABC):
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=segmentation.get('chunk_overlap', 0),
                 fixed_separator=separator,
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
         else:
@@ -65,7 +65,7 @@ class BaseIndexProcessor(ABC):
             character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )

View File

@@ -144,9 +144,9 @@ class DatasetRetrieval:
                                                                           float('inf')))
             for segment in sorted_segments:
                 if segment.answer:
-                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                    document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                 else:
-                    document_context_list.append(segment.content)
+                    document_context_list.append(segment.get_sign_content())
             if show_retrieve_source:
                 context_list = []
                 resource_number = 1

View File

@@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             documents.append(new_doc)
         return documents

-    def split_documents(self, documents: Iterable[Document] ) -> list[Document]:
+    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
         """Split documents."""
         texts, metadatas = [], []
         for doc in documents:

View File

@@ -99,9 +99,9 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                                                                           float('inf')))
                 for segment in sorted_segments:
                     if segment.answer:
-                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                        document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                     else:
-                        document_context_list.append(segment.content)
+                        document_context_list.append(segment.get_sign_content())
                 if self.return_resource:
                     context_list = []
                     resource_number = 1

View File

@@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                                                                           float('inf')))
             for segment in sorted_segments:
                 if segment.answer:
-                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                    document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                 else:
-                    document_context_list.append(segment.content)
+                    document_context_list.append(segment.get_sign_content())
             if self.return_resource:
                 context_list = []
                 resource_number = 1

View File

@@ -191,9 +191,9 @@ class KnowledgeRetrievalNode(BaseNode):
                     'title': document.name
                 }
                 if segment.answer:
-                    source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                    source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
                 else:
-                    source['content'] = segment.content
+                    source['content'] = segment.get_sign_content()
                 context_list.append(source)
                 resource_number += 1
         return context_list

View File

@@ -1,8 +1,15 @@
+import base64
+import hashlib
+import hmac
 import json
 import logging
+import os
 import pickle
+import re
+import time
 from json import JSONDecodeError

+from flask import current_app
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
@@ -414,6 +421,26 @@ class DocumentSegment(db.Model):
             DocumentSegment.position == self.position + 1
         ).first()

+    def get_sign_content(self):
+        pattern = r"/files/([a-f0-9\-]+)/image-preview"
+        text = self.content
+        match = re.search(pattern, text)
+
+        if match:
+            upload_file_id = match.group(1)
+            nonce = os.urandom(16).hex()
+            timestamp = str(int(time.time()))
+            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+            secret_key = current_app.config['SECRET_KEY'].encode()
+            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+            encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+            replacement = r"\g<0>?{params}".format(params=params)
+            text = re.sub(pattern, replacement, text)
+        return text
+

 class AppDatasetJoin(db.Model):
     __tablename__ = 'app_dataset_joins'
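get_sign_content() only signs the image-preview URLs embedded in segment content; the verifying side is not part of this diff. Below is a sketch of how such a signature could be checked by mirroring the signing logic above; the function name and the five-minute expiry window are assumptions, not code from this commit.

```python
import base64
import hashlib
import hmac
import time


def verify_image_preview_sign(upload_file_id: str, timestamp: str, nonce: str,
                              sign: str, secret_key: bytes, expires_in: int = 300) -> bool:
    """Recompute the HMAC produced by get_sign_content() and compare.

    The expiry window (expires_in) is an assumed policy, not taken from this diff.
    """
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    expected = base64.urlsafe_b64encode(
        hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    ).decode()
    if not hmac.compare_digest(expected, sign):
        return False
    return int(time.time()) - int(timestamp) <= expires_in
```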