Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ class OpenSearchClientArguments(TypedDict, total=False):
hosts: str | list[dict] | None
use_ssl: bool
verify_certs: bool
url_prefix: str
timeout: int
headers: dict[str, str] | None
http_compress: bool | None
opaque_id: str | None
scheme: str
hosts: str | list[dict] | None
connection_class: type[OpenSearchConnectionClass] | None
http_auth: tuple[str, str]

Expand Down Expand Up @@ -64,8 +71,18 @@ def __init__(
self.conn_id = open_search_conn_id
self.log_query = log_query

self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False)))
self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
extra = self.conn.extra_dejson
self.use_ssl = to_boolean(str(extra.get("use_ssl", False)))
self.verify_certs = to_boolean(str(extra.get("verify_certs", False)))
self.url_prefix = extra.get("url_prefix")
self.timeout = int(extra["timeout"]) if extra.get("timeout") is not None else None
self.headers = extra.get("headers")
self.http_compress = (
to_boolean(str(extra["http_compress"])) if extra.get("http_compress") is not None else None
)
self.opaque_id = extra.get("opaque_id")
self.scheme = extra.get("scheme")

self.connection_class = open_search_conn_class
self.__SERVICE = "es"

Expand All @@ -82,6 +99,18 @@ def client(self) -> OpenSearch:
verify_certs=self.verify_certs,
connection_class=self.connection_class,
)
if self.scheme:
client_args["scheme"] = self.scheme
if self.url_prefix:
client_args["url_prefix"] = self.url_prefix
if self.timeout is not None:
client_args["timeout"] = self.timeout
if self.headers is not None:
client_args["headers"] = self.headers
if self.http_compress is not None:
client_args["http_compress"] = self.http_compress
if self.opaque_id is not None:
client_args["opaque_id"] = self.opaque_id
if self.conn.login and self.conn.password:
client_args["http_auth"] = (self.conn.login, self.conn.password)
client = OpenSearch(**client_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,39 @@ def test_delete_check_parameters(self):
hook.delete(index_name="test_index")

@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_hook_param_bool(self, mock_get_connection):
@mock.patch("airflow.providers.opensearch.hooks.opensearch.OpenSearch")
def test_hook_extra_params(self, mock_opensearch, mock_get_connection):
mock_conn = Connection(
conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"}
conn_id="opensearch_default",
host="opensearch.local",
port=9200,
extra={
"use_ssl": True,
"verify_certs": False,
"scheme": "https",
"timeout": 30,
"http_compress": True,
"url_prefix": "os",
"headers": {"x-trace-id": "abc123"},
"opaque_id": "request-1",
},
)
mock_get_connection.return_value = mock_conn
hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)

assert isinstance(hook.use_ssl, bool)
assert isinstance(hook.verify_certs, bool)
hook.client

mock_opensearch.assert_called_once_with(
hosts=[{"host": "opensearch.local", "port": 9200}],
use_ssl=True,
verify_certs=False,
scheme="https",
timeout=30,
http_compress=True,
url_prefix="os",
headers={"x-trace-id": "abc123"},
opaque_id="request-1",
connection_class=DEFAULT_CONN,
)

def test_load_conn_param(self, mock_hook):
hook_default = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
Expand Down
Loading