mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
support images and tables extract from docx (#4619)
This commit is contained in:
parent
5893ebec55
commit
233c4150d1
|
@ -16,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||||
from core.rag.extractor.text_extractor import TextExtractor
|
from core.rag.extractor.text_extractor import TextExtractor
|
||||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
|
||||||
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
||||||
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
|
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
|
||||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||||
|
@ -108,7 +107,7 @@ class ExtractProcessor:
|
||||||
elif file_extension in ['.htm', '.html']:
|
elif file_extension in ['.htm', '.html']:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension in ['.docx']:
|
elif file_extension in ['.docx']:
|
||||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
elif file_extension == '.csv':
|
elif file_extension == '.csv':
|
||||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||||
elif file_extension == '.msg':
|
elif file_extension == '.msg':
|
||||||
|
@ -137,7 +136,7 @@ class ExtractProcessor:
|
||||||
elif file_extension in ['.htm', '.html']:
|
elif file_extension in ['.htm', '.html']:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension in ['.docx']:
|
elif file_extension in ['.docx']:
|
||||||
extractor = WordExtractor(file_path)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
elif file_extension == '.csv':
|
elif file_extension == '.csv':
|
||||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||||
elif file_extension == 'epub':
|
elif file_extension == 'epub':
|
||||||
|
|
|
@ -1,12 +1,20 @@
|
||||||
"""Abstract interface for document loader implementations."""
|
"""Abstract interface for document loader implementations."""
|
||||||
|
import datetime
|
||||||
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import uuid
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from docx import Document as DocxDocument
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
from core.rag.extractor.extractor_base import BaseExtractor
|
from core.rag.extractor.extractor_base import BaseExtractor
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
|
|
||||||
class WordExtractor(BaseExtractor):
|
class WordExtractor(BaseExtractor):
|
||||||
|
@ -17,9 +25,12 @@ class WordExtractor(BaseExtractor):
|
||||||
file_path: Path to the file to load.
|
file_path: Path to the file to load.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file_path: str):
|
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||||
"""Initialize with file path."""
|
"""Initialize with file path."""
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.user_id = user_id
|
||||||
|
|
||||||
if "~" in self.file_path:
|
if "~" in self.file_path:
|
||||||
self.file_path = os.path.expanduser(self.file_path)
|
self.file_path = os.path.expanduser(self.file_path)
|
||||||
|
|
||||||
|
@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor):
|
||||||
|
|
||||||
def extract(self) -> list[Document]:
|
def extract(self) -> list[Document]:
|
||||||
"""Load given path as single page."""
|
"""Load given path as single page."""
|
||||||
from docx import Document as docx_Document
|
content = self.parse_docx(self.file_path, 'storage')
|
||||||
|
|
||||||
document = docx_Document(self.file_path)
|
|
||||||
doc_texts = [paragraph.text for paragraph in document.paragraphs]
|
|
||||||
content = '\n'.join(doc_texts)
|
|
||||||
|
|
||||||
return [Document(
|
return [Document(
|
||||||
page_content=content,
|
page_content=content,
|
||||||
metadata={"source": self.file_path},
|
metadata={"source": self.file_path},
|
||||||
|
@ -61,3 +67,111 @@ class WordExtractor(BaseExtractor):
|
||||||
"""Check if the url is valid."""
|
"""Check if the url is valid."""
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
return bool(parsed.netloc) and bool(parsed.scheme)
|
||||||
|
|
||||||
|
def _extract_images_from_docx(self, doc, image_folder):
|
||||||
|
os.makedirs(image_folder, exist_ok=True)
|
||||||
|
image_count = 0
|
||||||
|
image_map = {}
|
||||||
|
|
||||||
|
for rel in doc.part.rels.values():
|
||||||
|
if "image" in rel.target_ref:
|
||||||
|
image_count += 1
|
||||||
|
image_ext = rel.target_ref.split('.')[-1]
|
||||||
|
# user uuid as file name
|
||||||
|
file_uuid = str(uuid.uuid4())
|
||||||
|
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
|
||||||
|
mime_type, _ = mimetypes.guess_type(file_key)
|
||||||
|
|
||||||
|
storage.save(file_key, rel.target_part.blob)
|
||||||
|
# save file to db
|
||||||
|
config = current_app.config
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
storage_type=config['STORAGE_TYPE'],
|
||||||
|
key=file_key,
|
||||||
|
name=file_key,
|
||||||
|
size=0,
|
||||||
|
extension=image_ext,
|
||||||
|
mime_type=mime_type,
|
||||||
|
created_by=self.user_id,
|
||||||
|
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||||
|
used=True,
|
||||||
|
used_by=self.user_id,
|
||||||
|
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
|
||||||
|
|
||||||
|
return image_map
|
||||||
|
|
||||||
|
def _table_to_markdown(self, table):
|
||||||
|
markdown = ""
|
||||||
|
# deal with table headers
|
||||||
|
header_row = table.rows[0]
|
||||||
|
headers = [cell.text for cell in header_row.cells]
|
||||||
|
markdown += "| " + " | ".join(headers) + " |\n"
|
||||||
|
markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
|
||||||
|
# deal with table rows
|
||||||
|
for row in table.rows[1:]:
|
||||||
|
row_cells = [cell.text for cell in row.cells]
|
||||||
|
markdown += "| " + " | ".join(row_cells) + " |\n"
|
||||||
|
|
||||||
|
return markdown
|
||||||
|
|
||||||
|
def _parse_paragraph(self, paragraph, image_map):
|
||||||
|
paragraph_content = []
|
||||||
|
for run in paragraph.runs:
|
||||||
|
if run.element.xpath('.//a:blip'):
|
||||||
|
for blip in run.element.xpath('.//a:blip'):
|
||||||
|
embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||||
|
if embed_id:
|
||||||
|
rel_target = run.part.rels[embed_id].target_ref
|
||||||
|
if rel_target in image_map:
|
||||||
|
paragraph_content.append(image_map[rel_target])
|
||||||
|
if run.text.strip():
|
||||||
|
paragraph_content.append(run.text.strip())
|
||||||
|
return ' '.join(paragraph_content) if paragraph_content else ''
|
||||||
|
|
||||||
|
def parse_docx(self, docx_path, image_folder):
|
||||||
|
doc = DocxDocument(docx_path)
|
||||||
|
os.makedirs(image_folder, exist_ok=True)
|
||||||
|
|
||||||
|
content = []
|
||||||
|
|
||||||
|
image_map = self._extract_images_from_docx(doc, image_folder)
|
||||||
|
|
||||||
|
def parse_paragraph(paragraph):
|
||||||
|
paragraph_content = []
|
||||||
|
for run in paragraph.runs:
|
||||||
|
if run.element.tag.endswith('r'):
|
||||||
|
drawing_elements = run.element.findall(
|
||||||
|
'.//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
|
||||||
|
for drawing in drawing_elements:
|
||||||
|
blip_elements = drawing.findall(
|
||||||
|
'.//{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
|
||||||
|
for blip in blip_elements:
|
||||||
|
embed_id = blip.get(
|
||||||
|
'{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||||
|
if embed_id:
|
||||||
|
image_part = doc.part.related_parts.get(embed_id)
|
||||||
|
if image_part in image_map:
|
||||||
|
paragraph_content.append(image_map[image_part])
|
||||||
|
if run.text.strip():
|
||||||
|
paragraph_content.append(run.text.strip())
|
||||||
|
return ''.join(paragraph_content) if paragraph_content else ''
|
||||||
|
|
||||||
|
paragraphs = doc.paragraphs.copy()
|
||||||
|
tables = doc.tables.copy()
|
||||||
|
for element in doc.element.body:
|
||||||
|
if element.tag.endswith('p'): # paragraph
|
||||||
|
para = paragraphs.pop(0)
|
||||||
|
parsed_paragraph = parse_paragraph(para)
|
||||||
|
if parsed_paragraph:
|
||||||
|
content.append(parsed_paragraph)
|
||||||
|
elif element.tag.endswith('tbl'): # table
|
||||||
|
table = tables.pop(0)
|
||||||
|
content.append(self._table_to_markdown(table))
|
||||||
|
return '\n'.join(content)
|
||||||
|
|
||||||
|
|
|
@ -144,9 +144,9 @@ class DatasetRetrieval:
|
||||||
float('inf')))
|
float('inf')))
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
|
||||||
else:
|
else:
|
||||||
document_context_list.append(segment.content)
|
document_context_list.append(segment.get_sign_content())
|
||||||
if show_retrieve_source:
|
if show_retrieve_source:
|
||||||
context_list = []
|
context_list = []
|
||||||
resource_number = 1
|
resource_number = 1
|
||||||
|
|
|
@ -99,9 +99,9 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||||
float('inf')))
|
float('inf')))
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
|
||||||
else:
|
else:
|
||||||
document_context_list.append(segment.content)
|
document_context_list.append(segment.get_sign_content())
|
||||||
if self.return_resource:
|
if self.return_resource:
|
||||||
context_list = []
|
context_list = []
|
||||||
resource_number = 1
|
resource_number = 1
|
||||||
|
|
|
@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||||
float('inf')))
|
float('inf')))
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
|
||||||
else:
|
else:
|
||||||
document_context_list.append(segment.content)
|
document_context_list.append(segment.get_sign_content())
|
||||||
if self.return_resource:
|
if self.return_resource:
|
||||||
context_list = []
|
context_list = []
|
||||||
resource_number = 1
|
resource_number = 1
|
||||||
|
|
|
@ -191,9 +191,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||||
'title': document.name
|
'title': document.name
|
||||||
}
|
}
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
|
||||||
else:
|
else:
|
||||||
source['content'] = segment.content
|
source['content'] = segment.get_sign_content()
|
||||||
context_list.append(source)
|
context_list.append(source)
|
||||||
resource_number += 1
|
resource_number += 1
|
||||||
return context_list
|
return context_list
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import time
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
@ -414,6 +421,26 @@ class DocumentSegment(db.Model):
|
||||||
DocumentSegment.position == self.position + 1
|
DocumentSegment.position == self.position + 1
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
|
def get_sign_content(self):
|
||||||
|
pattern = r"/files/([a-f0-9\-]+)/image-preview"
|
||||||
|
text = self.content
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
upload_file_id = match.group(1)
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
replacement = r"\g<0>?{params}".format(params=params)
|
||||||
|
text = re.sub(pattern, replacement, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AppDatasetJoin(db.Model):
|
class AppDatasetJoin(db.Model):
|
||||||
__tablename__ = 'app_dataset_joins'
|
__tablename__ = 'app_dataset_joins'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user