chore: fix indention violations by applying E111 to E117 ruff rules (#4925)

Bowen Liang 2024-06-05 14:05:15 +08:00 committed by GitHub
parent 6b6afb7708
commit f32b440c4a
14 changed files with 73 additions and 61 deletions

View File

@@ -36,7 +36,7 @@ jobs:
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: ruff check ./api
run: ruff check --preview ./api
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -528,4 +528,3 @@ class BaseAgentRunner(AppRunner):
return UserPromptMessage(content=prompt_message_contents)
else:
return UserPromptMessage(content=message.query)

View File

@@ -57,23 +57,23 @@ class BaichuanModel:
}[model]
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
return message
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
return message
def _handle_chat_stream_generate_response(self, response) -> Generator:
for line in response.iter_lines():

View File

@@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
model_prefix = model.split('.')[0]
if model_prefix == "amazon" :
for text in texts:
body = {
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
result = TextEmbeddingResult(
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
)
return result
)
return result
if model_prefix == "cohere" :
input_type = 'search_document' if len(texts) > 1 else 'search_query'
for text in texts:
body = {
input_type = 'search_document' if len(texts) > 1 else 'search_query'
for text in texts:
body = {
"texts": [text],
"input_type": input_type,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
token_usage += len(text)
result = TextEmbeddingResult(
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
token_usage += len(text)
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
)
return result
)
return result
#others
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
)
return usage
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
@@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
content_type = 'application/json'
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))

View File

@@ -54,7 +54,7 @@ class PGVectoRS(BaseVector):
class _Table(CollectionORM):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True} # noqa: RUF012
__table_args__ = {"extend_existing": True}
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,

View File

@@ -190,7 +190,7 @@ class RelytVector(BaseVector):
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e)) # noqa: T201
print("Delete operation failed:", str(e))
return False
def delete_by_metadata_field(self, key: str, value: str):

View File

@@ -50,7 +50,7 @@ class BaseDocumentTransformer(ABC):
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
"""
@abstractmethod
def transform_documents(

View File

@@ -68,7 +68,7 @@ class ArxivAPIWrapper(BaseModel):
Args:
query: a plaintext search query
""" # noqa: E501
"""
try:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results

View File

@@ -121,4 +121,5 @@ class SearXNGSearchTool(BuiltinTool):
query=query,
search_type=search_type,
result_type=result_type,
topK=num_results)
topK=num_results
)

View File

@@ -30,7 +30,7 @@ class TwilioAPIWrapper(BaseModel):
Twilio also work here. You cannot, for example, spoof messages from a private
cell phone number. If you are using `messaging_service_sid`, this parameter
must be empty.
""" # noqa: E501
"""
@validator("client", pre=True, always=True)
def set_validator(cls, values: dict) -> dict:
@@ -60,7 +60,7 @@ class TwilioAPIWrapper(BaseModel):
SMS/MMS or
[Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
for other 3rd-party channels.
""" # noqa: E501
"""
message = self.client.messages.create(to, from_=self.from_number, body=body)
return message.sid

View File

@@ -332,10 +332,11 @@ class Tool(BaseModel, ABC):
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
message=text,
save_as=save_as
)
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=text,
save_as=save_as
)
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
@@ -344,7 +345,8 @@ class Tool(BaseModel, ABC):
:param blob: the blob
:return: the blob message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
message=blob, meta=meta,
save_as=save_as
)
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=blob, meta=meta,
save_as=save_as
)

View File

@@ -13,8 +13,18 @@ select = [
"F", # pyflakes rules
"I", # isort rules
"UP", # pyupgrade rules
"E101", # mixed-spaces-and-tabs
"E111", # indentation-with-invalid-multiple
"E112", # no-indented-block
"E113", # unexpected-indentation
"E115", # no-indented-block-comment
"E116", # unexpected-indentation-comment
"E117", # over-indented
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"S506", # unsafe-yaml-load
"W191", # tab-indentation
"W605", # invalid-escape-sequence
]
ignore = [
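
Note: E111 through E117 are pycodestyle indentation rules that Ruff treats as preview rules at the time of this change, which appears to be why the ruff check commands in the workflow and scripts of this commit gain the --preview flag. As a minimal sketch of what these rules catch (a hypothetical snippet, not taken from this repository):

# Hypothetical example of the indentation violations the newly enabled rules flag.
def parse_choice(choice):
      return choice['message']['content']  # six-space indent: E111 (not a multiple of four) and E117 (over-indented)

# Re-indented to satisfy E111/E117, matching the fixes applied in the files above.
def parse_choice(choice):
    return choice['message']['content']  # four-space indent passes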

View File

@@ -9,7 +9,7 @@ if ! command -v ruff &> /dev/null; then
fi
# run ruff linter
ruff check --fix ./api
ruff check --fix --preview ./api
# env files linting relies on `dotenv-linter` in path
if ! command -v dotenv-linter &> /dev/null; then

View File

@@ -31,7 +31,7 @@ if $api_modified; then
pip install ruff
fi
ruff check ./api || status=$?
ruff check --preview ./api || status=$?
status=${status:-0}