mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
compatible with original provider name
This commit is contained in:
parent
8c2dbe876f
commit
f798add31c
|
@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument
|
|||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
|
@ -639,6 +640,18 @@ where sites.id is null limit 1000"""
|
|||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
|
||||
|
||||
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
|
||||
def migrate_data_for_plugin():
|
||||
"""
|
||||
Migrate data for plugin.
|
||||
"""
|
||||
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
|
||||
|
||||
PluginDataMigration.migrate()
|
||||
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
|
@ -649,3 +662,4 @@ def register_commands(app):
|
|||
app.cli.add_command(create_tenant)
|
||||
app.cli.add_command(upgrade_db)
|
||||
app.cli.add_command(fix_app_site_missing)
|
||||
app.cli.add_command(migrate_data_for_plugin)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
|
@ -53,7 +54,15 @@ class ModelConfigManager:
|
|||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
if "provider" not in config["model"]:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
|
@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel):
|
|||
return list(self.values())
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
|
@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel):
|
|||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations.get(key, default)
|
||||
|
||||
|
||||
|
|
0
api/services/plugin/__init__.py
Normal file
0
api/services/plugin/__init__.py
Normal file
184
api/services/plugin/data_migration.py
Normal file
184
api/services/plugin/data_migration.py
Normal file
|
@ -0,0 +1,184 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
cls.migrate_db_records("providers", "provider_name") # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name")
|
||||
cls.migrate_db_records("provider_orders", "provider_name")
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name")
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name")
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
||||
cls.migrate_datasets()
|
||||
cls.migrate_db_records("embeddings", "provider_name") # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
retrieval_model = i.retrieval_model
|
||||
print(type(retrieval_model))
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
retrieval_model_changed = False
|
||||
if retrieval_model:
|
||||
if (
|
||||
"reranking_model" in retrieval_model
|
||||
and "reranking_provider_name" in retrieval_model["reranking_model"]
|
||||
and retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
):
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating {table_name} {record_id} "
|
||||
f"(reranking_provider_name: "
|
||||
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
||||
)
|
||||
retrieval_model_changed = True
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
params = {"record_id": record_id}
|
||||
update_retrieval_model_sql = ""
|
||||
if retrieval_model and retrieval_model_changed:
|
||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), {"record_id": record_id})
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
Loading…
Reference in New Issue
Block a user