Skip to content

Commit 76cf664

Browse files
authored
Support pre-filtering RagnarStores in ragnar_retrieve() (#29)
* use new duckdb array->matrix support * Add support for tbl_sql in ragnar_retrieve functions * accept store-tbl in `ragnar_retrieve()`; support pre-filtering * redocument * `R CMD check` fixes * use mock `embed` function in example
1 parent da608a6 commit 76cf664

11 files changed

Lines changed: 327 additions & 21 deletions

DESCRIPTION

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ License: MIT + file LICENSE
1515
Encoding: UTF-8
1616
Roxygen: list(markdown = TRUE)
1717
RoxygenNote: 7.3.2
18+
Depends:
19+
R (>= 4.3.0)
1820
Imports:
1921
DBI,
20-
duckdb,
22+
duckdb (>= 1.2.2),
2123
glue,
2224
rlang (>= 1.1.0),
2325
dplyr,
@@ -44,6 +46,7 @@ Suggests:
4446
readr,
4547
rmarkdown,
4648
stringr,
49+
dbplyr,
4750
testthat (>= 3.0.0),
4851
paws.common,
4952
shiny

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@ importFrom(DBI,dbQuoteString)
3434
importFrom(DBI,dbReadTable)
3535
importFrom(DBI,dbWriteTable)
3636
importFrom(dotty,.)
37+
importFrom(dplyr,arrange)
3738
importFrom(dplyr,bind_rows)
39+
importFrom(dplyr,collect)
40+
importFrom(dplyr,filter)
41+
importFrom(dplyr,mutate)
42+
importFrom(dplyr,select)
43+
importFrom(dplyr,sql)
44+
importFrom(dplyr,tbl)
3845
importFrom(glue,as_glue)
3946
importFrom(glue,glue)
4047
importFrom(glue,glue_data)
@@ -80,6 +87,7 @@ importFrom(stringi,stri_trim_both)
8087
importFrom(tibble,as_tibble)
8188
importFrom(tibble,tibble)
8289
importFrom(tidyr,unchop)
90+
importFrom(utils,head)
8391
importFrom(vctrs,data_frame)
8492
importFrom(vctrs,list_unchop)
8593
importFrom(vctrs,new_data_frame)

R/embed-bedrock.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#' There are no guardarails for the kind of model that is used, but the model
77
#' must be available in the AWS region specified by the profile.
88
#' You may look for available models in the Bedrock Model Catalog
9+
#' @param profile AWS profile to use.
910
#' @param api_args Additional arguments to pass to the Bedrock API. Dependending
1011
#' on the `model`, you might be able to provide different parameters. Check
1112
#' the documentation for the model you are using in the

R/ragnar-package.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dotty::.
1616

1717
.onLoad <- function(libname, pkgname) {
1818
Sys.setenv(RETICULATE_PYTHON = "managed")
19+
S7::methods_register()
1920
reticulate::py_require(c(
2021
"markitdown[all]"
2122
))

R/retrieve.R

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ ragnar_retrieve_vss <- function(
3838
check_string(text)
3939
check_number_whole(top_k)
4040
method <- rlang::arg_match(method)
41+
if (inherits(store, "tbl_sql")) {
42+
tbl <- store
43+
return(ragnar_retrieve_vss_tbl(tbl, text, top_k, method))
44+
}
4145

4246
cols <- names(store@schema) |>
4347
stringi::stri_subset_regex("^embedding$", negate = TRUE) |>
@@ -66,22 +70,22 @@ ragnar_retrieve_vss <- function(
6670
# store |> dplyr::mutate(score = calculate_vss(store, text))
6771
# using dbplyr
6872
calculate_vss <- function(store, text, method) {
69-
if (is.null(store@embed)) {
73+
embed <- get_store_embed(store)
74+
if (is.null(embed)) {
7075
cli::cli_abort("Store must have an embed function but got {.code NULL}")
7176
}
7277

73-
embedded_text <- store@embed(text)
78+
embedded_text <- embed(text)
7479
embedding_size <- ncol(embedded_text)
7580

7681
.[method_function, ..] <- method_to_info(method)
77-
7882
glue::glue(
7983
r"---(
8084
{method_function}(
8185
embedding,
8286
[{stri_flatten(embedded_text, ", ")}]::FLOAT[{embedding_size}]
8387
)
84-
)---"
88+
)---"
8589
)
8690
}
8791

@@ -94,10 +98,71 @@ method_to_info <- function(method) {
9498
euclidean_distance = c("array_distance", "ASC"),
9599
negative_dot_product = c("array_negative_dot_product", "ASC"),
96100
cosine_similarity = c("array_cosine_similarity", "DESC"),
97-
dot_product = c("array_dot_product", "DESC")
101+
dot_product = c("array_dot_product", "DESC"),
102+
stop("Unknown method")
98103
)
99104
}
100105

106+
107+
get_store_embed <- function(x) {
108+
if (S7_inherits(x, RagnarStore)) {
109+
return(x@embed)
110+
}
111+
112+
if (inherits(x, "tbl_sql")) {
113+
con <- dbplyr::remote_con(x)
114+
ptr <- con@conn_ref
115+
embed <- attr(ptr, "embed_function", exact = TRUE)
116+
if (!is.null(embed)) {
117+
return(embed)
118+
}
119+
120+
# Attribute missing: reread from db and cache on ptr
121+
embed_blob <- DBI::dbGetQuery(
122+
con,
123+
"SELECT embed_func FROM metadata LIMIT 1"
124+
)$embed_func[[1]]
125+
embed <- unserialize(embed_blob)
126+
attr(ptr, "embed_function") <- embed
127+
return(embed)
128+
}
129+
130+
cli::cli_abort("`store` must be a RagnarStore or a dplyr::tbl()")
131+
}
132+
133+
134+
ragnar_retrieve_vss_tbl <- function(tbl, text, top_k, method) {
135+
.[.., order_key] <- method_to_info(method)
136+
tbl |>
137+
mutate(
138+
metric_value = sql(calculate_vss(tbl, text, method)),
139+
metric_name = method
140+
) |>
141+
select(-"embedding") |>
142+
arrange(sql(glue("metric_value {order_key}"))) |>
143+
head(n = top_k) |>
144+
collect()
145+
}
146+
147+
ragnar_retrieve_bm25_tbl_sql <- function(tbl, text, top_k) {
148+
con <- dbplyr::remote_con(tbl)
149+
text_quoted <- DBI::dbQuoteString(con, text)
150+
151+
tbl |>
152+
mutate(
153+
metric_value = sql(glue::glue(
154+
"fts_main_chunks.match_bm25(id, {text_quoted})"
155+
)),
156+
metric_name = "bm25"
157+
) |>
158+
filter(sql('metric_value IS NOT NULL')) |>
159+
arrange(.data$metric_value) |>
160+
select(-"embedding") |>
161+
head(n = top_k) |>
162+
collect()
163+
}
164+
165+
101166
#' Retrieves chunks using the BM25 score
102167
#'
103168
#' BM25 refers to Okapi Best Matching 25. See \doi{10.1561/1500000019} for more information.
@@ -108,6 +173,9 @@ method_to_info <- function(method) {
108173
ragnar_retrieve_bm25 <- function(store, text, top_k = 3L) {
109174
check_string(text)
110175
check_number_whole(top_k)
176+
if (inherits(store, "tbl_sql")) {
177+
return(ragnar_retrieve_bm25_tbl_sql(store, text, top_k))
178+
}
111179

112180
cols <- names(store@schema) |>
113181
stringi::stri_subset_regex("^embedding$", negate = TRUE) |>
@@ -168,7 +236,6 @@ ragnar_retrieve_vss_and_bm25 <- function(store, text, top_k = 3, ...) {
168236
)
169237

170238
# TODO: come up with a nice reordering that doesn't involve too much compute.
171-
172239
as_tibble(out)
173240
}
174241

@@ -178,18 +245,54 @@ ragnar_retrieve_vss_and_bm25 <- function(store, text, top_k = 3, ...) {
178245
#' [ragnar_retrieve()] is a thin wrapper around [ragnar_retrieve_vss_and_bm25()]
179246
#' using the recommended best practices.
180247
#'
181-
#' @param store A `RagnarStore` object.
248+
#' @param store A `RagnarStore` object or a `dplyr::tbl()` derived from
249+
#' it. When you pass a `tbl`, you may use usual dplyr verbs (e.g.
250+
#' `filter()`, `slice()`) to restrict the rows examined before
251+
#' similarity scoring. Avoid dropping essential columns such as
252+
#' `text`, `embedding`, `origin`, and `hash`.
182253
#' @param text A string to find the nearest match too
183254
#' @param top_k Integer, the number of nearest entries to find *per method*.
184255
#'
185256
#' @returns A dataframe of retrieved chunks. Each row corresponds to an
186257
#' individual chunk in the store. It always contains a column named `text`
187258
#' that contains the chunks.
188259
#'
260+
#' @section Pre-filtering with dplyr:
261+
#' The store behaves like a lazy table backed by DuckDB, so row‑wise
262+
#' filtering is executed directly in the database. This lets you narrow the
263+
#' search space efficiently without pulling data into R.
264+
#'
189265
#' @family ragnar_retrieve
190266
#' @export
267+
#' @examples
268+
#' # Basic usage
269+
#' mock_embed <- function(x) matrix(stats::runif(10), nrow = length(x), ncol = 10)
270+
#' store <- ragnar_store_create(embed = mock_embed)
271+
#' ragnar_store_insert(store, data.frame(text = c("foo", "bar")))
272+
#' ragnar_store_build_index(store)
273+
#' ragnar_retrieve(store, "foo")
274+
#'
275+
#' # More Advanced: store metadata, retrieve with pre-filtering
276+
#' store <- ragnar_store_create(
277+
#' embed = mock_embed,
278+
#' extra_cols = data.frame(category = character())
279+
#' )
280+
#' ragnar_store_insert(
281+
#' store,
282+
#' data.frame(
283+
#' category = c("desert", "desert", "desert", "meal", "meal", "meal"),
284+
#' text = c("ice cream", "cake", "cookies", "pasta", "burger", "salad")
285+
#' )
286+
#' )
287+
#' ragnar_store_build_index(store)
288+
#'
289+
#' # simple retrieve
290+
#' ragnar_retrieve(store, "yummy")
291+
#'
292+
#' # retrieve with pre-filtering
293+
#' dplyr::tbl(store) |>
294+
#' dplyr::filter(category == "meal") |>
295+
#' ragnar_retrieve("yummy")
191296
ragnar_retrieve <- function(store, text, top_k = 3L) {
192297
ragnar_retrieve_vss_and_bm25(store, text, top_k)
193298
}
194-
195-
# TODO: re-ranking.

R/store.R

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ ragnar_store_create <- function(
6969
stop("File already exists: ", location)
7070
}
7171
}
72-
con <- dbConnect(duckdb::duckdb(), dbdir = location)
72+
con <- dbConnect(duckdb::duckdb(), dbdir = location, array = "matrix")
7373

7474
default_schema <- vctrs::vec_ptype(data_frame(
7575
origin = character(0),
@@ -152,6 +152,10 @@ ragnar_store_create <- function(
152152
schema <- unserialize(metadata$schema[[1]])
153153
name <- metadata$name
154154

155+
# attach function to externalptr, so we can retreive it from just the connection.
156+
ptr <- con@conn_ref
157+
attr(ptr, "embed_function") <- embed
158+
155159
# duckdb R interface does not support array columns yet,
156160
# so we hand-write the sql.
157161
columns <- map2(names(schema), schema, function(nm, type) {
@@ -219,7 +223,12 @@ ragnar_store_connect <- function(
219223
# mode <- match.arg(mode)
220224
# read_only <- mode == "retrieve"
221225

222-
con <- dbConnect(duckdb::duckdb(), dbdir = location, read_only = read_only)
226+
con <- dbConnect(
227+
duckdb::duckdb(),
228+
dbdir = location,
229+
read_only = read_only,
230+
array = "matrix"
231+
)
223232

224233
# can't use dbExistsTable() because internally it runs:
225234
# > dbGetQuery(conn, sqlInterpolate(conn, "SELECT * FROM ? WHERE FALSE", dbQuoteIdentifier(conn, name)))
@@ -237,6 +246,10 @@ ragnar_store_connect <- function(
237246
schema <- unserialize(metadata$schema[[1L]])
238247
name <- metadata$name %||% unique_store_name()
239248

249+
# attach function to externalptr, so we can retreive it from just the connection.
250+
ptr <- con@conn_ref
251+
attr(ptr, "embed_function") <- embed
252+
240253
if (build_index) ragnar_store_build_index(con)
241254

242255
DuckDBRagnarStore(embed = embed, schema = schema, .con = con, name = name)
@@ -398,13 +411,13 @@ ragnar_store_insert <- function(store, chunks) {
398411

399412
# Ideally this would use dbWriteTable, but we can't really because it currently
400413
# doesn't support array columns.
401-
cols <- map2(names(schema), schema, function(nm, ptype) {
414+
cols <- imap(schema, function(ptype, name) {
402415
# Ensures that the column in chunks has the expected ptype. (or at least
403416
# something that can be cast to the correct ptype with no loss)
404417
col <- vctrs::vec_cast(
405-
chunks[[nm]],
418+
chunks[[name]],
406419
ptype,
407-
x_arg = glue::glue("chunks${nm}")
420+
x_arg = glue::glue("chunks${name}")
408421
)
409422

410423
if (is.matrix(col) && is.numeric(col)) {
@@ -418,7 +431,7 @@ ragnar_store_insert <- function(store, chunks) {
418431
} else if (is.numeric(col)) {
419432
DBI::dbQuoteLiteral(store@.con, col)
420433
} else {
421-
cli::cli_abort("Unsupported type {.cls {class(col)}")
434+
cli::cli_abort("Unsupported type {.cls {class(col)}}")
422435
}
423436
})
424437

@@ -518,3 +531,9 @@ ragnar_store_inspect <- function(store, ...) {
518531
})
519532
invisible(NULL)
520533
}
534+
535+
#' @importFrom dplyr tbl sql arrange collect
536+
method(tbl, ragnar:::DuckDBRagnarStore) <- function(src, from = "chunks", ...) {
537+
tbl(src@.con, from)
538+
}
539+
rm(tbl)

R/utils.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#' @importFrom xml2 xml_add_sibling xml_find_all xml_name xml_attr xml_text
1111
#' xml_url url_absolute xml_contents xml_find_first
1212
#' @importFrom tibble tibble as_tibble
13-
#' @importFrom dplyr bind_rows
13+
#' @importFrom dplyr bind_rows select mutate filter
1414
#' @importFrom tidyr unchop
1515
#' @importFrom vctrs data_frame vec_split vec_rbind vec_cbind vec_locate_matches
1616
#' vec_fill_missing vec_unique vec_slice vec_c list_unchop new_data_frame
@@ -21,6 +21,7 @@
2121
#' dbWriteTable dbListTables dbReadTable
2222
#' @importFrom glue glue glue_data as_glue
2323
#' @importFrom methods is
24+
#' @importFrom utils head
2425
#' @useDynLib ragnar, .registration = TRUE
2526
NULL
2627

@@ -57,6 +58,12 @@ map3 <- function(.x, .y, .z, .f, ...) {
5758
out
5859
}
5960

61+
imap <- function(.x, .f, ...) {
62+
out <- .mapply(.f, list(.x, names(.x) %||% seq_along(.x)), list(...))
63+
names(out) <- names(.x)
64+
out
65+
}
66+
6067
`%""%` <- function(x, y) {
6168
stopifnot(is.character(x), is.character(y))
6269
if (length(y) != length(x)) {

0 commit comments

Comments
 (0)