chore(api): Introduce Ruff Formatter. (#7291)

This commit is contained in:
-LAN- 2024-08-15 12:54:05 +08:00 committed by GitHub
parent 8f16165f92
commit 3571292fbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 1315 additions and 1335 deletions

View File

@ -45,6 +45,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

View File

@ -1,6 +1,6 @@
import os
if os.environ.get("DEBUG", "false").lower() != 'true':
if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey
monkey.patch_all()
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
if os.name == "nt":
os.system('tzutil /s "UTC"')
else:
os.environ['TZ'] = 'UTC'
os.environ["TZ"] = "UTC"
time.tzset()
@ -70,13 +70,14 @@ class DifyApp(Flask):
# -------------
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ''
os.environ[key] = ""
return dify_app
@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config['SECRET_KEY']
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get('LOG_FILE')
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
@ -111,23 +112,24 @@ def create_app() -> Flask:
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5
backupCount=5,
),
logging.StreamHandler(sys.stdout)
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),
datefmt=app.config.get('LOG_DATEFORMAT'),
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True
force=True,
)
log_tz = app.config.get('LOG_TZ')
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
@ -162,24 +164,24 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ['console', 'inner_api']:
if request.blueprint not in ["console", "inner_api"]:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get('_token')
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
raise Unauthorized("Invalid Authorization token.")
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
user_id = decoded.get("user_id")
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
if account:
@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(json.dumps({
'code': 'unauthorized',
'message': "Unauthorized."
}), status=401, content_type="application/json")
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
# register blueprint routers
@ -204,38 +207,36 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
CORS(web_bp,
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
CORS(console_app_bp,
resources={
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
@ -245,29 +246,29 @@ def register_blueprints(app):
app = create_app()
celery = app.extensions["celery"]
if app.config.get('TESTING'):
if app.config.get("TESTING"):
print("App is running in TESTING mode")
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
return response
@app.route('/health')
@app.route("/health")
def health():
return Response(json.dumps({
'pid': os.getpid(),
'status': 'ok',
'version': app.config['CURRENT_VERSION']
}), status=200, content_type="application/json")
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
status=200,
content_type="application/json",
)
@app.route('/threads')
@app.route("/threads")
def threads():
num_threads = threading.active_count()
threads = threading.enumerate()
@ -278,32 +279,34 @@ def threads():
thread_id = thread.ident
is_alive = thread.is_alive()
thread_list.append({
'name': thread_name,
'id': thread_id,
'is_alive': is_alive
})
thread_list.append(
{
"name": thread_name,
"id": thread_id,
"is_alive": is_alive,
}
)
return {
'pid': os.getpid(),
'thread_num': num_threads,
'threads': thread_list
"pid": os.getpid(),
"thread_num": num_threads,
"threads": thread_list,
}
@app.route('/db-pool-stat')
@app.route("/db-pool-stat")
def pool_stat():
engine = db.engine
return {
'pid': os.getpid(),
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

View File

@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.')
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
return
# generate password salt
@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command('reset-email', help='Reset the account email.')
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
'After the reset, all LLM credentials will become invalid, '
'requiring re-entry.'
'Only support SELF_HOSTED mode.')
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
' this operation cannot be rolled back!', fg='red'))
@click.command(
"reset-encrypt-key-pair",
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
"After the reset, all LLM credentials will become invalid, "
"requiring re-entry."
"Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
"""
Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
click.echo(
click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@click.command('vdb-migrate', help='migrate vector db.')
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
if scope in ["annotation", "all"]:
migrate_annotation_vector_database()
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
apps = (
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
click.echo("App annotation setting is disabled: {}".format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
collection_binding_id=dataset_collection_binding.id,
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
click.style(
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
click.echo(f"Successfully migrated app annotation {app.id}.")
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style(
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
datasets = (
db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
collection_name = ""
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'elasticsearch',
"vector_store": {"class_prefix": index_name}
}
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
click.style(
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
)
.all()
)
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
for segment in segments:
document = Document(
@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
},
)
documents.append(document)
@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
click.echo(f"Successfully migrated dataset {dataset.id}.")
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style('Start convert to agent apps.', fg='green'))
click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = []
@ -466,7 +480,7 @@ def convert_to_agent_apps():
break
for app in apps:
click.echo('Converting app: {}'.format(app.id))
click.echo("Converting app: {}".format(app.id))
try:
app.mode = AppMode.AGENT_CHAT.value
@ -478,137 +492,139 @@ def convert_to_agent_apps():
)
db.session.commit()
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
except Exception as e:
click.echo(
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError('Qdrant url is required.')
raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index
client.create_payload_index(binding.collection_name, field,
field_schema=PayloadSchemaType.KEYWORD)
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e:
click.echo(click.style('Failed to create qdrant client.', fg='red'))
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style('Sorry, email is required.', fg='red'))
click.echo(click.style("Sorry, email is required.", fg="red"))
return
# Create account
email = email.strip()
if '@' not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red'))
if "@" not in email:
click.echo(click.style("Sorry, invalid email address.", fg="red"))
return
account_name = email.split('@')[0]
account_name = email.split("@")[0]
if language not in languages:
language = 'en-US'
language = "en-US"
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language
)
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(click.style('Congratulations! Account and tenant created.\n'
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
click.echo(
click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command('upgrade-db', help='upgrade the database')
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
click.echo('Preparing database migration...')
lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style('Start database migration.', fg='green'))
click.echo(click.style("Start database migration.", fg="green"))
# run db migration
import flask_migrate
flask_migrate.upgrade()
click.echo(click.style('Database migration successful!', fg='green'))
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f'Database migration failed, error: {e}')
logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
click.echo('Database migration skipped')
click.echo("Database migration skipped")
@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style('Start fix app related site missing issue.', fg='green'))
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = []
while True:
@ -639,15 +655,14 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
logging.exception(f'Fix app related site missing issue failed, error: {e}')
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue
if not processed_count:
break
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
def register_commands(app):

View File

@ -1 +1 @@
HIDDEN_VALUE = '[__HIDDEN__]'
HIDDEN_VALUE = "[__HIDDEN__]"

View File

@ -1,22 +1,22 @@
language_timezone_mapping = {
'en-US': 'America/New_York',
'zh-Hans': 'Asia/Shanghai',
'zh-Hant': 'Asia/Taipei',
'pt-BR': 'America/Sao_Paulo',
'es-ES': 'Europe/Madrid',
'fr-FR': 'Europe/Paris',
'de-DE': 'Europe/Berlin',
'ja-JP': 'Asia/Tokyo',
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata',
'tr-TR': 'Europe/Istanbul',
'fa-IR': 'Asia/Tehran',
"en-US": "America/New_York",
"zh-Hans": "Asia/Shanghai",
"zh-Hant": "Asia/Taipei",
"pt-BR": "America/Sao_Paulo",
"es-ES": "Europe/Madrid",
"fr-FR": "Europe/Paris",
"de-DE": "Europe/Berlin",
"ja-JP": "Asia/Tokyo",
"ko-KR": "Asia/Seoul",
"ru-RU": "Europe/Moscow",
"it-IT": "Europe/Rome",
"uk-UA": "Europe/Kyiv",
"vi-VN": "Asia/Ho_Chi_Minh",
"ro-RO": "Europe/Bucharest",
"pl-PL": "Europe/Warsaw",
"hi-IN": "Asia/Kolkata",
"tr-TR": "Europe/Istanbul",
"fa-IR": "Asia/Tehran",
}
languages = list(language_timezone_mapping.keys())
@ -26,6 +26,5 @@ def supported_language(lang):
if lang in languages:
return lang
error = ('{lang} is not a valid language.'
.format(lang=lang))
error = "{lang} is not a valid language.".format(lang=lang)
raise ValueError(error)

View File

@ -5,82 +5,79 @@ from models.model import AppMode
default_app_templates = {
# workflow default mode
AppMode.WORKFLOW: {
'app': {
'mode': AppMode.WORKFLOW.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.WORKFLOW.value,
"enable_site": True,
"enable_api": True,
}
},
# completion default mode
AppMode.COMPLETION: {
'app': {
'mode': AppMode.COMPLETION.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.COMPLETION.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
"completion_params": {},
},
'user_input_form': json.dumps([
"user_input_form": json.dumps(
[
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
"default": "",
},
},
]
),
"pre_prompt": "{{query}}",
},
},
# chat default mode
AppMode.CHAT: {
'app': {
'mode': AppMode.CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
"completion_params": {},
},
},
},
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
'app': {
'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True,
'enable_api': True
}
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"enable_site": True,
"enable_api": True,
},
},
# agent-chat default mode
AppMode.AGENT_CHAT: {
'app': {
'mode': AppMode.AGENT_CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.AGENT_CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
}
"completion_params": {},
},
},
},
}

View File

@ -2,6 +2,6 @@ from contextvars import ContextVar
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar('tenant_id')
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")

View File

@ -1,13 +1,13 @@
from blinker import signal
# sender: app
app_was_created = signal('app-was-created')
app_was_created = signal("app-was-created")
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal('app-model-config-was-updated')
app_model_config_was_updated = signal("app-model-config-was-updated")
# sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced')
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: dataset
dataset_was_deleted = signal('dataset-was-deleted')
dataset_was_deleted = signal("dataset-was-deleted")

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: document
document_was_deleted = signal('document-was-deleted')
document_was_deleted = signal("document-was-deleted")

View File

@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender, **kwargs):
dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,
dataset.indexing_technique,
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
)

View File

@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task
@document_was_deleted.connect
def handle(sender, **kwargs):
document_id = sender
dataset_id = kwargs.get('dataset_id')
doc_form = kwargs.get('doc_form')
file_id = kwargs.get('file_id')
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -14,21 +14,25 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get('document_ids', None)
document_ids = kwargs.get("document_ids", None)
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = db.session.query(Document).filter(
document = (
db.session.query(Document)
.filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
Document.dataset_id == dataset_id,
)
.first()
)
if not document:
raise NotFound('Document not found')
raise NotFound("Document not found")
document.indexing_status = 'parsing'
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
@ -38,8 +42,8 @@ def handle(sender, **kwargs):
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -10,7 +10,7 @@ def handle(sender, **kwargs):
installed_app = InstalledApp(
tenant_id=app.tenant_id,
app_id=app.id,
app_owner_tenant_id=app.tenant_id
app_owner_tenant_id=app.tenant_id,
)
db.session.add(installed_app)
db.session.commit()

View File

@ -7,15 +7,15 @@ from models.model import Site
def handle(sender, **kwargs):
"""Create site record when an app is created."""
app = sender
account = kwargs.get('account')
account = kwargs.get("account")
site = Site(
app_id=app.id,
title=app.name,
icon=app.icon,
icon_background=app.icon_background,
default_language=account.interface_language,
customize_token_strategy='not_allow',
code=Site.generate_code(16)
customize_token_strategy="not_allow",
code=Site.generate_code(16),
)
db.session.add(site)

View File

@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get('application_generate_entity')
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
@ -39,7 +39,7 @@ def handle(sender, **kwargs):
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if 'gpt-4' in model_config.model:
if "gpt-4" in model_config.model:
used_quota = 20
else:
used_quota = 1
@ -50,6 +50,6 @@ def handle(sender, **kwargs):
Provider.provider_name == model_config.provider,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
Provider.quota_limit > Provider.quota_used,
).update({"quota_used": Provider.quota_used + used_quota})
db.session.commit()

View File

@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced
@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
app = sender
for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []):
if node_data.get('data', {}).get('type') == NodeType.TOOL.value:
for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
@ -23,7 +23,7 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
)
manager.delete_tool_parameters_cache()
except:

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: document
document_index_created = signal('document-index-created')
document_index_created = signal("document-index-created")

View File

@ -7,13 +7,11 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
app = sender
app_model_config = kwargs.get('app_model_config')
app_model_config = kwargs.get("app_model_config")
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id
).all()
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
if not app_dataset_joins:
@ -29,16 +27,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app.id,
dataset_id=dataset_id
)
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get('tools', []) or []
tools = agent_mode.get("tools", []) or []
for tool in tools:
if len(list(tool.keys())) != 1:
continue
@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict
datasets = dataset_configs.get('datasets', {}) or {}
for dataset in datasets.get('datasets', []) or []:
datasets = dataset_configs.get("datasets", {}) or {}
for dataset in datasets.get("datasets", []) or []:
keys = list(dataset.keys())
if len(keys) == 1 and keys[0] == 'dataset':
if dataset['dataset'].get('id'):
dataset_ids.add(dataset['dataset'].get('id'))
if len(keys) == 1 and keys[0] == "dataset":
if dataset["dataset"].get("id"):
dataset_ids.add(dataset["dataset"].get("id"))
return dataset_ids

View File

@ -11,13 +11,11 @@ from models.workflow import Workflow
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
app = sender
published_workflow = kwargs.get('published_workflow')
published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id
).all()
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
if not app_dataset_joins:
@ -33,16 +31,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app.id,
dataset_id=dataset_id
)
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
if not graph:
return dataset_ids
nodes = graph.get('nodes', [])
nodes = graph.get("nodes", [])
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [node for node in nodes
if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value]
knowledge_retrieval_nodes = [
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
]
if not knowledge_retrieval_nodes:
return dataset_ids
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get('data', {}))
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
dataset_ids.update(node_data.dataset_ids)
except Exception as e:
continue

View File

@ -9,13 +9,13 @@ from models.provider import Provider
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get('application_generate_entity')
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider
).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)})
Provider.provider_name == application_generate_entity.model_conf.provider,
).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
db.session.commit()

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: message, kwargs: conversation
message_was_created = signal('message-was-created')
message_was_created = signal("message-was-created")

View File

@ -1,7 +1,7 @@
from blinker import signal
# sender: tenant
tenant_was_created = signal('tenant-was-created')
tenant_was_created = signal("tenant-was-created")
# sender: tenant
tenant_was_updated = signal('tenant-was-updated')
tenant_was_updated = signal("tenant-was-updated")

View File

@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery:
]
day = app.config["CELERY_BEAT_SCHEDULER_TIME"]
beat_schedule = {
'clean_embedding_cache_task': {
'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
'schedule': timedelta(days=day),
"clean_embedding_cache_task": {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
"schedule": timedelta(days=day),
},
"clean_unused_datasets_task": {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
},
'clean_unused_datasets_task': {
'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
'schedule': timedelta(days=day),
}
}
celery_app.conf.update(
beat_schedule=beat_schedule,
imports=imports
)
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -2,15 +2,14 @@ from flask import Flask
def init_app(app: Flask):
if app.config.get('API_COMPRESSION_ENABLED'):
if app.config.get("API_COMPRESSION_ENABLED"):
from flask_compress import Compress
app.config['COMPRESS_MIMETYPES'] = [
'application/json',
'image/svg+xml',
'text/html',
app.config["COMPRESS_MIMETYPES"] = [
"application/json",
"image/svg+xml",
"text/html",
]
compress = Compress()
compress.init_app(app)

View File

@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
POSTGRES_INDEXES_NAMING_CONVENTION = {
'ix': '%(column_0_label)s_idx',
'uq': '%(table_name)s_%(column_0_name)s_key',
'ck': '%(table_name)s_%(constraint_name)s_check',
'fk': '%(table_name)s_%(column_0_name)s_fkey',
'pk': '%(table_name)s_pkey',
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)

View File

@ -14,67 +14,69 @@ class Mail:
return self._client is not None
def init_app(self, app: Flask):
if app.config.get('MAIL_TYPE'):
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
if app.config.get("MAIL_TYPE"):
if app.config.get("MAIL_DEFAULT_SEND_FROM"):
self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
if app.config.get('MAIL_TYPE') == 'resend':
api_key = app.config.get('RESEND_API_KEY')
if app.config.get("MAIL_TYPE") == "resend":
api_key = app.config.get("RESEND_API_KEY")
if not api_key:
raise ValueError('RESEND_API_KEY is not set')
raise ValueError("RESEND_API_KEY is not set")
api_url = app.config.get('RESEND_API_URL')
api_url = app.config.get("RESEND_API_URL")
if api_url:
resend.api_url = api_url
resend.api_key = api_key
self._client = resend.Emails
elif app.config.get('MAIL_TYPE') == 'smtp':
elif app.config.get("MAIL_TYPE") == "smtp":
from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'),
username=app.config.get('SMTP_USERNAME'),
password=app.config.get('SMTP_PASSWORD'),
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
use_tls=app.config.get('SMTP_USE_TLS'),
opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
server=app.config.get("SMTP_SERVER"),
port=app.config.get("SMTP_PORT"),
username=app.config.get("SMTP_USERNAME"),
password=app.config.get("SMTP_PASSWORD"),
_from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
use_tls=app.config.get("SMTP_USE_TLS"),
opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
)
else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
else:
logging.warning('MAIL_TYPE is not set')
logging.warning("MAIL_TYPE is not set")
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client:
raise ValueError('Mail client is not initialized')
raise ValueError("Mail client is not initialized")
if not from_ and self._default_send_from:
from_ = self._default_send_from
if not from_:
raise ValueError('mail from is not set')
raise ValueError("mail from is not set")
if not to:
raise ValueError('mail to is not set')
raise ValueError("mail to is not set")
if not subject:
raise ValueError('mail subject is not set')
raise ValueError("mail subject is not set")
if not html:
raise ValueError('mail html is not set')
raise ValueError("mail html is not set")
self._client.send({
self._client.send(
{
"from": from_,
"to": to,
"subject": subject,
"html": html
})
"html": html,
}
)
def init_app(app: Flask):

View File

@ -6,18 +6,21 @@ redis_client = redis.Redis()
def init_app(app):
connection_class = Connection
if app.config.get('REDIS_USE_SSL'):
if app.config.get("REDIS_USE_SSL"):
connection_class = SSLConnection
redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('REDIS_HOST'),
'port': app.config.get('REDIS_PORT'),
'username': app.config.get('REDIS_USERNAME'),
'password': app.config.get('REDIS_PASSWORD'),
'db': app.config.get('REDIS_DB'),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)
redis_client.connection_pool = redis.ConnectionPool(
**{
"host": app.config.get("REDIS_HOST"),
"port": app.config.get("REDIS_PORT"),
"username": app.config.get("REDIS_USERNAME"),
"password": app.config.get("REDIS_PASSWORD"),
"db": app.config.get("REDIS_DB"),
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
},
connection_class=connection_class,
)
app.extensions['redis'] = redis_client
app.extensions["redis"] = redis_client

View File

@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException
def init_app(app):
if app.config.get('SENTRY_DSN'):
if app.config.get("SENTRY_DSN"):
sentry_sdk.init(
dsn=app.config.get('SENTRY_DSN'),
integrations=[
FlaskIntegration(),
CeleryIntegration()
],
dsn=app.config.get("SENTRY_DSN"),
integrations=[FlaskIntegration(), CeleryIntegration()],
ignore_errors=[HTTPException, ValueError],
traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0),
profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0),
environment=app.config.get('DEPLOY_ENV'),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}"
traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
environment=app.config.get("DEPLOY_ENV"),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
)

View File

@ -17,31 +17,19 @@ class Storage:
self.storage_runner = None
def init_app(self, app: Flask):
storage_type = app.config.get('STORAGE_TYPE')
if storage_type == 's3':
self.storage_runner = S3Storage(
app=app
)
elif storage_type == 'azure-blob':
self.storage_runner = AzureStorage(
app=app
)
elif storage_type == 'aliyun-oss':
self.storage_runner = AliyunStorage(
app=app
)
elif storage_type == 'google-storage':
self.storage_runner = GoogleStorage(
app=app
)
elif storage_type == 'tencent-cos':
self.storage_runner = TencentStorage(
app=app
)
elif storage_type == 'oci-storage':
self.storage_runner = OCIStorage(
app=app
)
storage_type = app.config.get("STORAGE_TYPE")
if storage_type == "s3":
self.storage_runner = S3Storage(app=app)
elif storage_type == "azure-blob":
self.storage_runner = AzureStorage(app=app)
elif storage_type == "aliyun-oss":
self.storage_runner = AliyunStorage(app=app)
elif storage_type == "google-storage":
self.storage_runner = GoogleStorage(app=app)
elif storage_type == "tencent-cos":
self.storage_runner = TencentStorage(app=app)
elif storage_type == "oci-storage":
self.storage_runner = OCIStorage(app=app)
else:
self.storage_runner = LocalStorage(app=app)

View File

@ -8,23 +8,22 @@ from extensions.storage.base_storage import BaseStorage
class AliyunStorage(BaseStorage):
"""Implementation for aliyun storage.
"""
"""Implementation for aliyun storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME')
self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME")
oss_auth_method = aliyun_s3.Auth
region = None
if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4':
if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4":
oss_auth_method = aliyun_s3.AuthV4
region = app_config.get('ALIYUN_OSS_REGION')
oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY'))
region = app_config.get("ALIYUN_OSS_REGION")
oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY"))
self.client = aliyun_s3.Bucket(
oss_auth,
app_config.get('ALIYUN_OSS_ENDPOINT'),
app_config.get("ALIYUN_OSS_ENDPOINT"),
self.bucket_name,
connect_timeout=30,
region=region,

View File

@ -9,16 +9,15 @@ from extensions.storage.base_storage import BaseStorage
class AzureStorage(BaseStorage):
"""Implementation for azure storage.
"""
"""Implementation for azure storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME')
self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL')
self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME')
self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY')
self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME")
self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL")
self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME")
self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY")
def save(self, filename, data):
client = self._sync_client()
@ -39,6 +38,7 @@ class AzureStorage(BaseStorage):
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob.download_blob()
yield from blob_data.chunks()
return generate(filename)
def download(self, filename, target_filepath):
@ -62,17 +62,17 @@ class AzureStorage(BaseStorage):
blob_container.delete_blob(filename)
def _sync_client(self):
cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key)
cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
cache_result = redis_client.get(cache_key)
if cache_result is not None:
sas_token = cache_result.decode('utf-8')
sas_token = cache_result.decode("utf-8")
else:
sas_token = generate_account_sas(
account_name=self.account_name,
account_key=self.account_key,
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url, credential=sas_token)

View File

@ -1,4 +1,5 @@
"""Abstract interface for file storage implementations."""
from abc import ABC, abstractmethod
from collections.abc import Generator
@ -6,8 +7,8 @@ from flask import Flask
class BaseStorage(ABC):
"""Interface for file storage.
"""
"""Interface for file storage."""
app = None
def __init__(self, app: Flask):

View File

@ -11,16 +11,16 @@ from extensions.storage.base_storage import BaseStorage
class GoogleStorage(BaseStorage):
"""Implementation for google storage.
"""
"""Implementation for google storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME')
service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME")
service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64")
# if service_account_json_str is empty, use Application Default Credentials
if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode('utf-8')
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object
service_account_obj = json.loads(service_account_json)
self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj)
@ -43,9 +43,10 @@ class GoogleStorage(BaseStorage):
def generate(filename: str = filename) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
with closing(blob.open(mode='rb')) as blob_stream:
with closing(blob.open(mode="rb")) as blob_stream:
while chunk := blob_stream.read(4096):
yield chunk
return generate()
def download(self, filename, target_filepath):

View File

@ -8,21 +8,20 @@ from extensions.storage.base_storage import BaseStorage
class LocalStorage(BaseStorage):
"""Implementation for local storage.
"""
"""Implementation for local storage."""
def __init__(self, app: Flask):
super().__init__(app)
folder = self.app.config.get('STORAGE_LOCAL_PATH')
folder = self.app.config.get("STORAGE_LOCAL_PATH")
if not os.path.isabs(folder):
folder = os.path.join(app.root_path, folder)
self.folder = folder
def save(self, filename, data):
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True)
@ -31,10 +30,10 @@ class LocalStorage(BaseStorage):
f.write(data)
def load_once(self, filename: str) -> bytes:
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
@ -46,10 +45,10 @@ class LocalStorage(BaseStorage):
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
@ -61,10 +60,10 @@ class LocalStorage(BaseStorage):
return generate()
def download(self, filename, target_filepath):
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
@ -72,17 +71,17 @@ class LocalStorage(BaseStorage):
shutil.copyfile(filename, target_filepath)
def exists(self, filename):
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
return os.path.exists(filename)
def delete(self, filename):
if not self.folder or self.folder.endswith('/'):
if not self.folder or self.folder.endswith("/"):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
filename = self.folder + "/" + filename
if os.path.exists(filename):
os.remove(filename)

View File

@ -12,13 +12,13 @@ class OCIStorage(BaseStorage):
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('OCI_BUCKET_NAME')
self.bucket_name = app_config.get("OCI_BUCKET_NAME")
self.client = boto3.client(
's3',
aws_secret_access_key=app_config.get('OCI_SECRET_KEY'),
aws_access_key_id=app_config.get('OCI_ACCESS_KEY'),
endpoint_url=app_config.get('OCI_ENDPOINT'),
region_name=app_config.get('OCI_REGION')
"s3",
aws_secret_access_key=app_config.get("OCI_SECRET_KEY"),
aws_access_key_id=app_config.get("OCI_ACCESS_KEY"),
endpoint_url=app_config.get("OCI_ENDPOINT"),
region_name=app_config.get("OCI_REGION"),
)
def save(self, filename, data):
@ -27,9 +27,9 @@ class OCIStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
try:
with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
@ -40,12 +40,13 @@ class OCIStorage(BaseStorage):
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].iter_chunks()
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
return generate()
def download(self, filename, target_filepath):

View File

@ -10,23 +10,23 @@ from extensions.storage.base_storage import BaseStorage
class S3Storage(BaseStorage):
"""Implementation for s3 storage.
"""
"""Implementation for s3 storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('S3_BUCKET_NAME')
if app_config.get('S3_USE_AWS_MANAGED_IAM'):
self.bucket_name = app_config.get("S3_BUCKET_NAME")
if app_config.get("S3_USE_AWS_MANAGED_IAM"):
session = boto3.Session()
self.client = session.client('s3')
self.client = session.client("s3")
else:
self.client = boto3.client(
's3',
aws_secret_access_key=app_config.get('S3_SECRET_KEY'),
aws_access_key_id=app_config.get('S3_ACCESS_KEY'),
endpoint_url=app_config.get('S3_ENDPOINT'),
region_name=app_config.get('S3_REGION'),
config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')})
"s3",
aws_secret_access_key=app_config.get("S3_SECRET_KEY"),
aws_access_key_id=app_config.get("S3_ACCESS_KEY"),
endpoint_url=app_config.get("S3_ENDPOINT"),
region_name=app_config.get("S3_REGION"),
config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}),
)
def save(self, filename, data):
@ -35,9 +35,9 @@ class S3Storage(BaseStorage):
def load_once(self, filename: str) -> bytes:
try:
with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
@ -48,12 +48,13 @@ class S3Storage(BaseStorage):
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].iter_chunks()
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
return generate()
def download(self, filename, target_filepath):

View File

@ -7,18 +7,17 @@ from extensions.storage.base_storage import BaseStorage
class TencentStorage(BaseStorage):
"""Implementation for tencent cos storage.
"""
"""Implementation for tencent cos storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME')
self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME")
config = CosConfig(
Region=app_config.get('TENCENT_COS_REGION'),
SecretId=app_config.get('TENCENT_COS_SECRET_ID'),
SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'),
Scheme=app_config.get('TENCENT_COS_SCHEME'),
Region=app_config.get("TENCENT_COS_REGION"),
SecretId=app_config.get("TENCENT_COS_SECRET_ID"),
SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"),
Scheme=app_config.get("TENCENT_COS_SCHEME"),
)
self.client = CosS3Client(config)
@ -26,19 +25,19 @@ class TencentStorage(BaseStorage):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
def load_once(self, filename: str) -> bytes:
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read()
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].get_stream(chunk_size=4096)
yield from response["Body"].get_stream(chunk_size=4096)
return generate()
def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response['Body'].get_stream_to_file(target_filepath)
response["Body"].get_stream_to_file(target_filepath)
def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)

View File

@ -5,7 +5,7 @@ from libs.helper import TimestampField
annotation_fields = {
"id": fields.String,
"question": fields.String,
"answer": fields.Raw(attribute='content'),
"answer": fields.Raw(attribute="content"),
"hit_count": fields.Integer,
"created_at": TimestampField,
# 'account': fields.Nested(simple_account_fields, allow_null=True)
@ -21,8 +21,8 @@ annotation_hit_history_fields = {
"score": fields.Float,
"question": fields.String,
"created_at": TimestampField,
"match": fields.String(attribute='annotation_question'),
"response": fields.String(attribute='annotation_content')
"match": fields.String(attribute="annotation_question"),
"response": fields.String(attribute="annotation_content"),
}
annotation_hit_history_list_fields = {

View File

@ -8,16 +8,16 @@ class HiddenAPIKey(fields.Raw):
api_key = obj.api_key
# If the length of the api_key is less than 8 characters, show the first and last characters
if len(api_key) <= 8:
return api_key[0] + '******' + api_key[-1]
return api_key[0] + "******" + api_key[-1]
# If the api_key is greater than 8 characters, show the first three and the last three characters
else:
return api_key[:3] + '******' + api_key[-3:]
return api_key[:3] + "******" + api_key[-3:]
api_based_extension_fields = {
'id': fields.String,
'name': fields.String,
'api_endpoint': fields.String,
'api_key': HiddenAPIKey,
'created_at': TimestampField
"id": fields.String,
"name": fields.String,
"api_endpoint": fields.String,
"api_key": HiddenAPIKey,
"created_at": TimestampField,
}

View File

@ -3,157 +3,153 @@ from flask_restful import fields
from libs.helper import TimestampField
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'),
'icon': fields.String,
'icon_background': fields.String,
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
"data": fields.List(fields.Nested(app_detail_kernel_fields)),
"total": fields.Integer,
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'text_to_speech': fields.Raw(attribute='text_to_speech_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
'file_upload': fields.Raw(attribute='file_upload_dict'),
'created_at': TimestampField
"opening_statement": fields.String,
"suggested_questions": fields.Raw(attribute="suggested_questions_list"),
"suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"),
"speech_to_text": fields.Raw(attribute="speech_to_text_dict"),
"text_to_speech": fields.Raw(attribute="text_to_speech_dict"),
"retriever_resource": fields.Raw(attribute="retriever_resource_dict"),
"annotation_reply": fields.Raw(attribute="annotation_reply_dict"),
"more_like_this": fields.Raw(attribute="more_like_this_dict"),
"sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"),
"external_data_tools": fields.Raw(attribute="external_data_tools_list"),
"model": fields.Raw(attribute="model_dict"),
"user_input_form": fields.Raw(attribute="user_input_form_list"),
"dataset_query_variable": fields.String,
"pre_prompt": fields.String,
"agent_mode": fields.Raw(attribute="agent_mode_dict"),
"prompt_type": fields.String,
"chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"),
"completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"),
"dataset_configs": fields.Raw(attribute="dataset_configs_dict"),
"file_upload": fields.Raw(attribute="file_upload_dict"),
"created_at": TimestampField,
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'),
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
'tracing': fields.Raw,
'created_at': TimestampField
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
"tracing": fields.Raw,
"created_at": TimestampField,
}
prompt_config_fields = {
'prompt_template': fields.String,
"prompt_template": fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'max_active_requests': fields.Raw(),
'description': fields.String(attribute='desc_or_prompt'),
'mode': fields.String(attribute='mode_compatible_with_agent'),
'icon': fields.String,
'icon_background': fields.String,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
'created_at': TimestampField,
'tags': fields.List(fields.Nested(tag_fields))
"id": fields.String,
"name": fields.String,
"max_active_requests": fields.Raw(),
"description": fields.String(attribute="desc_or_prompt"),
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True),
"created_at": TimestampField,
"tags": fields.List(fields.Nested(tag_fields)),
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
}
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
"name": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"mode": fields.String,
"model_config": fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
"data": fields.List(fields.Nested(template_fields)),
}
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'chat_color_theme': fields.String,
'chat_color_theme_inverted': fields.Boolean,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'custom_disclaimer': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
'show_workflow_steps': fields.Boolean,
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"default_language": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"app_base_url": fields.String,
"show_workflow_steps": fields.Boolean,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'mode': fields.String(attribute='mode_compatible_with_agent'),
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField,
'deleted_tools': fields.List(fields.String),
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"created_at": TimestampField,
"deleted_tools": fields.List(fields.String),
}
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'custom_disclaimer': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'show_workflow_steps': fields.Boolean,
"app_id": fields.String,
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"default_language": fields.String,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"show_workflow_steps": fields.Boolean,
}

View File

@ -6,205 +6,202 @@ from libs.helper import TimestampField
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
return value[0]["text"] if value else ""
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(simple_account_fields, allow_null=True),
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_fields, allow_null=True),
}
annotation_fields = {
'id': fields.String,
'question': fields.String,
'content': fields.String,
'account': fields.Nested(simple_account_fields, allow_null=True),
'created_at': TimestampField
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
annotation_hit_history_fields = {
'annotation_id': fields.String(attribute='id'),
'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True),
'created_at': TimestampField
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
message_file_fields = {
'id': fields.String,
'type': fields.String,
'url': fields.String,
'belongs_to': fields.String(default='user'),
"id": fields.String,
"type": fields.String,
"url": fields.String,
"belongs_to": fields.String(default="user"),
}
agent_thought_fields = {
'id': fields.String,
'chain_id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_labels': fields.Raw,
'tool_input': fields.String,
'created_at': TimestampField,
'observation': fields.String,
'files': fields.List(fields.String),
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String(attribute='re_sign_file_url_answer'),
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'workflow_run_id': fields.String,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'metadata': fields.Raw(attribute='message_metadata_dict'),
'status': fields.String,
'error': fields.String,
"id": fields.String,
"conversation_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_fields)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
}
simple_configs_fields = {
'prompt_template': fields.String,
"prompt_template": fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
"inputs": fields.Raw,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"model_config": fields.Nested(simple_model_config_fields),
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
}
conversation_message_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_fields),
"message": fields.Nested(message_detail_fields, attribute="first_message"),
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'name': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
}
conversation_with_summary_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
}
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'introduction': fields.String,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
}
simple_conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
"id": fields.String,
"name": fields.String,
"inputs": fields.Raw,
"status": fields.String,
"introduction": fields.String,
"created_at": TimestampField,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(simple_conversation_fields))
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(simple_conversation_fields)),
}
conversation_with_model_config_fields = {
**simple_conversation_fields,
'model_config': fields.Raw,
"model_config": fields.Raw,
}
conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(conversation_with_model_config_fields)),
}

View File

@ -3,19 +3,19 @@ from flask_restful import fields
from libs.helper import TimestampField
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.String,
'description': fields.String,
'created_at': TimestampField,
'updated_at': TimestampField,
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value": fields.String,
"description": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
paginated_conversation_variable_fields = {
'page': fields.Integer,
'limit': fields.Integer,
'total': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"),
}

View File

@ -2,64 +2,56 @@ from flask_restful import fields
from libs.helper import TimestampField
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(integrate_page_fields)),
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
"notion_info": fields.List(fields.Nested(integrate_workspace_fields)),
}
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
"parent_id": fields.String,
"type": fields.String,
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(integrate_page_fields)),
"total": fields.Integer,
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
"id": fields.String,
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"disabled": fields.Boolean,
"link": fields.String,
"source_info": fields.Nested(integrate_workspace_fields),
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
"data": fields.List(fields.Nested(integrate_fields)),
}

View File

@ -3,73 +3,64 @@ from flask_restful import fields
from libs.helper import TimestampField
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
"id": fields.String,
"name": fields.String,
"description": fields.String,
"permission": fields.String,
"data_source_type": fields.String,
"indexing_technique": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
}
reranking_model_fields = {
'reranking_provider_name': fields.String,
'reranking_model_name': fields.String
}
reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}
keyword_setting_fields = {
'keyword_weight': fields.Float
}
keyword_setting_fields = {"keyword_weight": fields.Float}
vector_setting_fields = {
'vector_weight': fields.Float,
'embedding_model_name': fields.String,
'embedding_provider_name': fields.String,
"vector_weight": fields.Float,
"embedding_model_name": fields.String,
"embedding_provider_name": fields.String,
}
weighted_score_fields = {
'keyword_setting': fields.Nested(keyword_setting_fields),
'vector_setting': fields.Nested(vector_setting_fields),
"keyword_setting": fields.Nested(keyword_setting_fields),
"vector_setting": fields.Nested(vector_setting_fields),
}
dataset_retrieval_model_fields = {
'search_method': fields.String,
'reranking_enable': fields.Boolean,
'reranking_mode': fields.String,
'reranking_model': fields.Nested(reranking_model_fields),
'weights': fields.Nested(weighted_score_fields, allow_null=True),
'top_k': fields.Integer,
'score_threshold_enabled': fields.Boolean,
'score_threshold': fields.Float
"search_method": fields.String,
"reranking_enable": fields.Boolean,
"reranking_mode": fields.String,
"reranking_model": fields.Nested(reranking_model_fields),
"weights": fields.Nested(weighted_score_fields, allow_null=True),
"top_k": fields.Integer,
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
'tags': fields.List(fields.Nested(tag_fields))
"id": fields.String,
"name": fields.String,
"description": fields.String,
"provider": fields.String,
"permission": fields.String,
"data_source_type": fields.String,
"indexing_technique": fields.String,
"app_count": fields.Integer,
"document_count": fields.Integer,
"word_count": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"embedding_model": fields.String,
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)),
}
dataset_query_detail_fields = {
@ -79,7 +70,5 @@ dataset_query_detail_fields = {
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
"created_at": TimestampField,
}

View File

@ -4,75 +4,73 @@ from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
"id": fields.String,
"position": fields.Integer,
"data_source_type": fields.String,
"data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String,
"name": fields.String,
"created_from": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"tokens": fields.Integer,
"indexing_status": fields.String,
"error": fields.String,
"enabled": fields.Boolean,
"disabled_at": TimestampField,
"disabled_by": fields.String,
"archived": fields.Boolean,
"display_status": fields.String,
"word_count": fields.Integer,
"hit_count": fields.Integer,
"doc_form": fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
"id": fields.String,
"position": fields.Integer,
"data_source_type": fields.String,
"data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String,
"name": fields.String,
"created_from": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"tokens": fields.Integer,
"indexing_status": fields.String,
"error": fields.String,
"enabled": fields.Boolean,
"disabled_at": TimestampField,
"disabled_by": fields.String,
"archived": fields.Boolean,
"display_status": fields.String,
"word_count": fields.Integer,
"hit_count": fields.Integer,
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
}
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
"dataset": fields.Nested(dataset_fields),
"documents": fields.List(fields.Nested(document_fields)),
"batch": fields.String,
}
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
"id": fields.String,
"indexing_status": fields.String,
"processing_started_at": TimestampField,
"parsing_completed_at": TimestampField,
"cleaning_completed_at": TimestampField,
"splitting_completed_at": TimestampField,
"completed_at": TimestampField,
"paused_at": TimestampField,
"error": fields.String,
"stopped_at": TimestampField,
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))}

View File

@ -1,8 +1,8 @@
from flask_restful import fields
simple_end_user_fields = {
'id': fields.String,
'type': fields.String,
'is_anonymous': fields.Boolean,
'session_id': fields.String,
"id": fields.String,
"type": fields.String,
"is_anonymous": fields.Boolean,
"session_id": fields.String,
}

View File

@ -3,17 +3,17 @@ from flask_restful import fields
from libs.helper import TimestampField
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer,
'image_file_size_limit': fields.Integer,
"file_size_limit": fields.Integer,
"batch_count_limit": fields.Integer,
"image_file_size_limit": fields.Integer,
}
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
}

View File

@ -3,39 +3,39 @@ from flask_restful import fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
"id": fields.String,
"data_source_type": fields.String,
"name": fields.String,
"doc_type": fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
"id": fields.String,
"position": fields.Integer,
"document_id": fields.String,
"content": fields.String,
"answer": fields.String,
"word_count": fields.Integer,
"tokens": fields.Integer,
"keywords": fields.List(fields.String),
"index_node_id": fields.String,
"index_node_hash": fields.String,
"hit_count": fields.Integer,
"enabled": fields.Boolean,
"disabled_at": TimestampField,
"disabled_by": fields.String,
"status": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"indexing_at": TimestampField,
"completed_at": TimestampField,
"error": fields.String,
"stopped_at": TimestampField,
"document": fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
"segment": fields.Nested(segment_fields),
"score": fields.Float,
"tsne_position": fields.Raw,
}

View File

@ -3,23 +3,21 @@ from flask_restful import fields
from libs.helper import TimestampField
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_background": fields.String,
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean
"id": fields.String,
"app": fields.Nested(app_fields),
"app_owner_tenant_id": fields.String,
"is_pinned": fields.Boolean,
"last_used_at": TimestampField,
"editable": fields.Boolean,
"uninstallable": fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}
installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))}

View File

@ -2,38 +2,32 @@ from flask_restful import fields
from libs.helper import TimestampField
simple_account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'is_password_set': fields.Boolean,
'interface_language': fields.String,
'interface_theme': fields.String,
'timezone': fields.String,
'last_login_at': TimestampField,
'last_login_ip': fields.String,
'created_at': TimestampField
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"email": fields.String,
"is_password_set": fields.Boolean,
"interface_language": fields.String,
"interface_theme": fields.String,
"timezone": fields.String,
"last_login_at": TimestampField,
"last_login_ip": fields.String,
"created_at": TimestampField,
}
account_with_role_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'last_login_at': TimestampField,
'last_active_at': TimestampField,
'created_at': TimestampField,
'role': fields.String,
'status': fields.String,
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"email": fields.String,
"last_login_at": TimestampField,
"last_active_at": TimestampField,
"created_at": TimestampField,
"role": fields.String,
"status": fields.String,
}
account_with_role_list_fields = {
'accounts': fields.List(fields.Nested(account_with_role_fields))
}
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}

View File

@ -3,83 +3,79 @@ from flask_restful import fields
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
feedback_fields = {
'rating': fields.String
}
feedback_fields = {"rating": fields.String}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
feedback_fields = {
'rating': fields.String
}
feedback_fields = {"rating": fields.String}
agent_thought_fields = {
'id': fields.String,
'chain_id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_labels': fields.Raw,
'tool_input': fields.String,
'created_at': TimestampField,
'observation': fields.String,
'files': fields.List(fields.String)
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String(attribute='re_sign_file_url_answer'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'status': fields.String,
'error': fields.String,
"id": fields.String,
"conversation_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"status": fields.String,
"error": fields.String,
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}

View File

@ -3,31 +3,31 @@ from flask_restful import fields
from libs.helper import TimestampField
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
"id": fields.String,
"position": fields.Integer,
"document_id": fields.String,
"content": fields.String,
"answer": fields.String,
"word_count": fields.Integer,
"tokens": fields.Integer,
"keywords": fields.List(fields.String),
"index_node_id": fields.String,
"index_node_hash": fields.String,
"hit_count": fields.Integer,
"enabled": fields.Boolean,
"disabled_at": TimestampField,
"disabled_by": fields.String,
"status": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"indexing_at": TimestampField,
"completed_at": TimestampField,
"error": fields.String,
"stopped_at": TimestampField,
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
"data": fields.List(fields.Nested(segment_fields)),
"has_more": fields.Boolean,
"limit": fields.Integer,
}

View File

@ -1,8 +1,3 @@
from flask_restful import fields
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String,
'binding_count': fields.String
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}

View File

@ -7,18 +7,18 @@ from libs.helper import TimestampField
workflow_app_log_partial_fields = {
"id": fields.String,
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True),
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
"created_from": fields.String,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
"created_at": TimestampField
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField,
}
workflow_app_log_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items')
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"),
}

View File

@ -13,43 +13,43 @@ class EnvironmentVariableField(fields.Raw):
# Mask secret variables values in environment_variables
if isinstance(value, SecretVariable):
return {
'id': value.id,
'name': value.name,
'value': encrypter.obfuscated_token(value.value),
'value_type': value.value_type.value,
"id": value.id,
"name": value.name,
"value": encrypter.obfuscated_token(value.value),
"value_type": value.value_type.value,
}
if isinstance(value, Variable):
return {
'id': value.id,
'name': value.name,
'value': value.value,
'value_type': value.value_type.value,
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.value,
}
if isinstance(value, dict):
value_type = value.get('value_type')
value_type = value.get("value_type")
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f'Unsupported environment variable value type: {value_type}')
raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value
conversation_variable_fields = {
'id': fields.String,
'name': fields.String,
'value_type': fields.String(attribute='value_type.value'),
'value': fields.Raw,
'description': fields.String,
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value": fields.Raw,
"description": fields.String,
}
workflow_fields = {
'id': fields.String,
'graph': fields.Raw(attribute='graph_dict'),
'features': fields.Raw(attribute='features_dict'),
'hash': fields.String(attribute='unique_hash'),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
'created_at': TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
'updated_at': TimestampField,
'tool_published': fields.Boolean,
'environment_variables': fields.List(EnvironmentVariableField()),
'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)),
"id": fields.String,
"graph": fields.Raw(attribute="graph_dict"),
"features": fields.Raw(attribute="features_dict"),
"hash": fields.String(attribute="unique_hash"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
"tool_published": fields.Boolean,
"environment_variables": fields.List(EnvironmentVariableField()),
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
}

View File

@ -13,7 +13,7 @@ workflow_run_for_log_fields = {
"total_tokens": fields.Integer,
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField
"finished_at": TimestampField,
}
workflow_run_for_list_fields = {
@ -24,9 +24,9 @@ workflow_run_for_list_fields = {
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField
"finished_at": TimestampField,
}
advanced_chat_workflow_run_for_list_fields = {
@ -39,40 +39,40 @@ advanced_chat_workflow_run_for_list_fields = {
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField
"finished_at": TimestampField,
}
advanced_chat_workflow_run_pagination_fields = {
'limit': fields.Integer(attribute='limit'),
'has_more': fields.Boolean(attribute='has_more'),
'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data')
"limit": fields.Integer(attribute="limit"),
"has_more": fields.Boolean(attribute="has_more"),
"data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"),
}
workflow_run_pagination_fields = {
'limit': fields.Integer(attribute='limit'),
'has_more': fields.Boolean(attribute='has_more'),
'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data')
"limit": fields.Integer(attribute="limit"),
"has_more": fields.Boolean(attribute="has_more"),
"data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"),
}
workflow_run_detail_fields = {
"id": fields.String,
"sequence_number": fields.Integer,
"version": fields.String,
"graph": fields.Raw(attribute='graph_dict'),
"inputs": fields.Raw(attribute='inputs_dict'),
"graph": fields.Raw(attribute="graph_dict"),
"inputs": fields.Raw(attribute="inputs_dict"),
"status": fields.String,
"outputs": fields.Raw(attribute='outputs_dict'),
"outputs": fields.Raw(attribute="outputs_dict"),
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_steps": fields.Integer,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField
"finished_at": TimestampField,
}
workflow_run_node_execution_fields = {
@ -82,21 +82,21 @@ workflow_run_node_execution_fields = {
"node_id": fields.String,
"node_type": fields.String,
"title": fields.String,
"inputs": fields.Raw(attribute='inputs_dict'),
"process_data": fields.Raw(attribute='process_data_dict'),
"outputs": fields.Raw(attribute='outputs_dict'),
"inputs": fields.Raw(attribute="inputs_dict"),
"process_data": fields.Raw(attribute="process_data_dict"),
"outputs": fields.Raw(attribute="outputs_dict"),
"status": fields.String,
"error": fields.String,
"elapsed_time": fields.Float,
"execution_metadata": fields.Raw(attribute='execution_metadata_dict'),
"execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
"extras": fields.Raw,
"created_at": TimestampField,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
"finished_at": TimestampField
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
}
workflow_run_node_execution_list_fields = {
'data': fields.List(fields.Nested(workflow_run_node_execution_fields)),
"data": fields.List(fields.Nested(workflow_run_node_execution_fields)),
}

View File

@ -69,7 +69,18 @@ ignore = [
]
# Settings for `ruff format`. `exclude` lists glob patterns the formatter skips.
# NOTE(review): presumably these trees are excluded so the formatter can be rolled
# out incrementally — confirm, and shrink the list as each tree gets formatted.
# NOTE(review): `quote-style = "single"` disagrees with the double-quoted output
# shown elsewhere in this diff; it looks like the removed side of the hunk — confirm.
[tool.ruff.format]
quote-style = "single"
exclude = [
"core/**/*.py",
"controllers/**/*.py",
"models/**/*.py",
"utils/**/*.py",
"migrations/**/*",
"services/**/*.py",
"tasks/**/*.py",
"tests/**/*.py",
"libs/**/*.py",
"configs/**/*.py",
]
[tool.pytest_env]
OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"

View File

@ -11,27 +11,32 @@ from extensions.ext_database import db
from models.dataset import Embedding
# Celery task (queue "dataset") that deletes rows from the `embeddings` table whose
# `created_at` is older than `dify_config.CLEAN_DAY_SETTING` days, 100 ids per query,
# and prints the elapsed time when done. Takes no arguments and returns nothing.
# NOTE(review): this span is a rendered diff. Each reformatted statement appears twice
# (old form, then ruff-formatted form) and indentation is not shown, so block nesting
# must be checked against the real file.
@app.celery.task(queue='dataset')
@app.celery.task(queue="dataset")
def clean_embedding_cache_task():
click.echo(click.style('Start clean embedding cache.', fg='green'))
click.echo(click.style("Start clean embedding cache.", fg="green"))
# Retention window in days, coerced to int before being used in timedelta below.
clean_days = int(dify_config.CLEAN_DAY_SETTING)
start_at = time.perf_counter()
# Cutoff timestamp. Despite the name, the window is `clean_days`, not a fixed 30 days.
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
# Repeat until a query returns no more expired ids.
while True:
try:
embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \
.order_by(Embedding.created_at.desc()).limit(100).all()
# Fetch up to 100 expired embedding ids, newest first.
embedding_ids = (
db.session.query(Embedding.id)
.filter(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
)
# Each result row carries only the id column; unwrap it into a plain list.
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
# NOTE(review): nothing in the visible query obviously raises NotFound, so this
# handler may be dead code — confirm before relying on it as the loop exit.
except NotFound:
break
if embedding_ids:
# Delete row by row using a bound parameter (no string interpolation into the SQL).
for embedding_id in embedding_ids:
db.session.execute(text(
"DELETE FROM embeddings WHERE id = :embedding_id"
), {'embedding_id': embedding_id})
db.session.execute(
text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id}
)
# NOTE(review): whether this commit runs once per row or once per batch depends on
# indentation that this view does not show — confirm.
db.session.commit()
else:
# Empty batch: nothing older than the cutoff remains.
break
end_at = time.perf_counter()
click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green'))
click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green"))

View File

@ -12,9 +12,9 @@ from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document
# Celery task (queue "dataset") that looks for datasets created before the cutoff that
# have no recently updated documents, at least one older document, and no recent
# DatasetQuery rows; for each it cleans the index and sets `Document.enabled` to False.
# NOTE(review): this span is a rendered diff. Each reformatted statement appears twice
# (old form, then ruff-formatted form), indentation is not shown, and the `@ -a,b +c,d @`
# markers inside the body stand for lines that are not visible here.
@app.celery.task(queue='dataset')
@app.celery.task(queue="dataset")
def clean_unused_datasets_task():
click.echo(click.style('Start clean unused datasets indexes.', fg='green'))
click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
# NOTE(review): passed straight to timedelta(days=...) with no int() cast — confirm
# CLEAN_DAY_SETTING is already numeric.
clean_days = dify_config.CLEAN_DAY_SETTING
start_at = time.perf_counter()
# Cutoff timestamp. Despite the name, the window is `clean_days`, not a fixed 30 days.
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
@ -22,40 +22,44 @@ def clean_unused_datasets_task():
while True:
try:
# Subquery for counting new documents
document_subquery_new = db.session.query(
Document.dataset_id,
func.count(Document.id).label('document_count')
).filter(
Document.indexing_status == 'completed',
# Per-dataset count of completed, enabled, non-archived documents updated AFTER the cutoff.
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > thirty_days_ago
).group_by(Document.dataset_id).subquery()
Document.updated_at > thirty_days_ago,
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = db.session.query(
Document.dataset_id,
func.count(Document.id).label('document_count')
).filter(
Document.indexing_status == 'completed',
# Same count, but for documents updated BEFORE the cutoff.
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < thirty_days_ago
).group_by(Document.dataset_id).subquery()
Document.updated_at < thirty_days_ago,
)
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter
datasets = (db.session.query(Dataset)
.outerjoin(
document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id
).outerjoin(
document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id
).filter(
# Outer joins keep datasets with no matching documents; coalesce turns their
# missing counts into 0. Selected: created before the cutoff, zero "new"
# documents, at least one "old" document; newest first, 50 per page.
datasets = (
db.session.query(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < thirty_days_ago,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0
).order_by(
Dataset.created_at.desc()
).paginate(page=page, per_page=50))
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
# presumably raised by paginate() for an out-of-range page — confirm
except NotFound:
break
@ -63,10 +67,11 @@ def clean_unused_datasets_task():
break
# NOTE(review): the loop below disables documents, and disabled documents no longer
# count toward the "old" subquery, so processed datasets drop out of the result set.
# Advancing `page` as well may therefore skip datasets that shift into earlier
# pages — confirm against the full file.
page += 1
for dataset in datasets:
dataset_query = db.session.query(DatasetQuery).filter(
DatasetQuery.created_at > thirty_days_ago,
DatasetQuery.dataset_id == dataset.id
).all()
# DatasetQuery rows for this dataset created after the cutoff.
dataset_query = (
db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id)
.all()
)
# Proceed only when there are no such rows; the len() test is redundant with `not`.
if not dataset_query or len(dataset_query) == 0:
try:
# remove index
@ -74,17 +79,14 @@ def clean_unused_datasets_task():
# `index_processor` is assigned on a line hidden by the hunk marker above.
index_processor.clean(dataset, None)
# update document
update_params = {
Document.enabled: False
}
# Set enabled=False on every document of this dataset, then commit.
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
fg='green'))
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
# Best-effort: any failure is printed in red instead of being raised.
except Exception as e:
click.echo(
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
end_at = time.perf_counter()
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green"))

View File

@ -11,5 +11,8 @@ fi
# run ruff linter
ruff check --fix ./api
# run ruff formatter
ruff format ./api
# run dotenv-linter linter
dotenv-linter ./api/.env.example ./web/.env.example