Skip to content

Commit da608a6

Browse files
authored
air format . (#28)
1 parent f08ba32 commit da608a6

21 files changed

Lines changed: 868 additions & 595 deletions

R/chunk.R

Lines changed: 89 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
1-
2-
31
pick_cut_positions <- function(candidates, chunk_size) {
4-
.Call(pick_cut_positions_,
5-
as.integer(candidates),
6-
as.integer(chunk_size))
2+
.Call(pick_cut_positions_, as.integer(candidates), as.integer(chunk_size))
73
}
84

9-
str_chunk1 <- function(string,
10-
candidate_cutpoints,
11-
# assuming:
12-
# 1 token ~ 4 characters
13-
# one page ~ 400 tokens
14-
# target chunk size ~ 1 page
15-
max_size = 1600L,
16-
trim = TRUE) {
17-
if(isTRUE(is.na(string)))
18-
return(NA_character_)
5+
str_chunk1 <- function(
6+
string,
7+
candidate_cutpoints,
8+
# assuming:
9+
# 1 token ~ 4 characters
10+
# one page ~ 400 tokens
11+
# target chunk size ~ 1 page
12+
max_size = 1600L,
13+
trim = TRUE
14+
) {
15+
if (isTRUE(is.na(string))) return(NA_character_)
1916
check_string(string, allow_na = TRUE)
2017
string_len <- stri_length(string)
21-
if (string_len <= max_size)
22-
return(string)
18+
if (string_len <= max_size) return(string)
2319

2420
candidate_cutpoints <- c(
25-
1L, as.integer(candidate_cutpoints), string_len
21+
1L,
22+
as.integer(candidate_cutpoints),
23+
string_len
2624
)
2725

2826
cut_points <- pick_cut_positions(candidate_cutpoints, max_size)
29-
chunks <- stri_sub(string, drop_last(cut_points), drop_first(cut_points),
30-
use_matrix = FALSE)
27+
chunks <- stri_sub(
28+
string,
29+
drop_last(cut_points),
30+
drop_first(cut_points),
31+
use_matrix = FALSE
32+
)
3133

32-
if (trim)
33-
chunks <- stri_trim_both(chunks)
34+
if (trim) chunks <- stri_trim_both(chunks)
3435

3536
chunks <- chunks[nzchar(chunks)]
3637

@@ -53,7 +54,9 @@ str_locate_boundaries1 <- function(string, boundary) {
5354
## then split on raw vector.
5455
## ... or use stringi to convert byte to char indexes, e.g.,
5556
## stri_split_boundaries(x, type = "char")[[1]] |> stri_numbytes()
56-
paragraph = stri_locate_all_fixed(string, "\n\n", omit_no_match = TRUE)[[1L]][, "end"],
57+
paragraph = stri_locate_all_fixed(string, "\n\n", omit_no_match = TRUE)[[
58+
1L
59+
]][, "end"],
5760

5861
# Note, stri_opts_brkiter 'type = line_break' is really about finding
5962
# candidates line break for the purpose of line wrapping a string, not
@@ -62,11 +65,17 @@ str_locate_boundaries1 <- function(string, boundary) {
6265
# stri_split_lines() does more comprehensive identification of line
6366
# breaks, but isn't exported as a boundary detector. Most text passing
6467
# through here is expected to have been normalized as markdown already...
65-
line_break = stri_locate_all_fixed(string, "\n", omit_no_match = TRUE)[[1L]][, "end"],
68+
line_break = stri_locate_all_fixed(string, "\n", omit_no_match = TRUE)[[
69+
1L
70+
]][, "end"],
6671

6772
sentence = ,
6873
word = ,
69-
character = stri_locate_all_boundaries(string, type = boundary, locale = "@ss=standard")[[1L]][, "end"],
74+
character = stri_locate_all_boundaries(
75+
string,
76+
type = boundary,
77+
locale = "@ss=standard"
78+
)[[1L]][, "end"],
7079
stop(
7180
'boundaries values must be one of: "paragraph", "sentence", "line_break", "word", "character" or a stringr pattern'
7281
)
@@ -75,11 +84,13 @@ str_locate_boundaries1 <- function(string, boundary) {
7584
locations
7685
}
7786

78-
str_chunk <- function(x, max_size,
79-
boundaries = c("paragraph", "sentence", "line_break", "word", "character"),
80-
trim = TRUE,
81-
simplify = TRUE) {
82-
87+
str_chunk <- function(
88+
x,
89+
max_size,
90+
boundaries = c("paragraph", "sentence", "line_break", "word", "character"),
91+
trim = TRUE,
92+
simplify = TRUE
93+
) {
8394
chunk1 <- function(string, boundary) {
8495
str_chunk1(
8596
string,
@@ -96,21 +107,22 @@ str_chunk <- function(x, max_size,
96107
repeat {
97108
lens <- stri_length(chunks)
98109
is_over_size <- lens > max_size
99-
if (!any(is_over_size, na.rm = TRUE))
100-
break
110+
if (!any(is_over_size, na.rm = TRUE)) break
101111
boundaries <- boundaries[-1L]
102-
if (!length(boundaries))
103-
break
112+
if (!length(boundaries)) break
104113
chunks <- as.list(chunks)
105-
chunks[is_over_size] <- lapply(chunks[is_over_size], chunk1, boundaries[[1L]])
114+
chunks[is_over_size] <- lapply(
115+
chunks[is_over_size],
116+
chunk1,
117+
boundaries[[1L]]
118+
)
106119
chunks <- unlist(chunks)
107120
# TODO: recurse and returned nested list of strings if simplify=FALSE
108121
}
109122
chunks
110123
})
111124

112-
if (simplify)
113-
out <- unlist(out)
125+
if (simplify) out <- unlist(out)
114126

115127
out
116128
}
@@ -217,9 +229,14 @@ str_chunk <- function(x, max_size,
217229
#' @name ragnar_chunk
218230
#' @rdname ragnar_chunk
219231
#' @export
220-
ragnar_chunk <- function(x, max_size = 1600L,
221-
boundaries = c("paragraph", "sentence", "line_break", "word", "character"),
222-
..., trim = TRUE, simplify = TRUE) {
232+
ragnar_chunk <- function(
233+
x,
234+
max_size = 1600L,
235+
boundaries = c("paragraph", "sentence", "line_break", "word", "character"),
236+
...,
237+
trim = TRUE,
238+
simplify = TRUE
239+
) {
223240
if (is.data.frame(x)) {
224241
check_character(x[["text"]])
225242
x[["text"]] <- str_chunk(
@@ -230,8 +247,7 @@ ragnar_chunk <- function(x, max_size = 1600L,
230247
...,
231248
simplify = FALSE
232249
)
233-
if (simplify)
234-
x <- tidyr::unchop(x, "text")
250+
if (simplify) x <- tidyr::unchop(x, "text")
235251
} else {
236252
boundaries <- as_boundaries_list(boundaries)
237253
x <- str_chunk(
@@ -248,45 +264,55 @@ ragnar_chunk <- function(x, max_size = 1600L,
248264

249265
#' @export
250266
#' @rdname ragnar_chunk
251-
ragnar_segment <- function(x,
252-
boundaries = "sentence",
253-
...,
254-
trim = FALSE,
255-
simplify = TRUE) {
267+
ragnar_segment <- function(
268+
x,
269+
boundaries = "sentence",
270+
...,
271+
trim = FALSE,
272+
simplify = TRUE
273+
) {
256274
if (is.data.frame(x)) {
257275
check_character(x[["text"]])
258-
x[["text"]] <- ragnar_segment(x[["text"]],
259-
boundaries = boundaries,
260-
trim = trim,
261-
...,
262-
simplify = FALSE)
263-
if (simplify)
264-
x <- tidyr::unchop(x, "text")
276+
x[["text"]] <- ragnar_segment(
277+
x[["text"]],
278+
boundaries = boundaries,
279+
trim = trim,
280+
...,
281+
simplify = FALSE
282+
)
283+
if (simplify) x <- tidyr::unchop(x, "text")
265284
return(x)
266285
}
267286

268287
boundaries <- as_boundaries_list(boundaries)
269288
check_character(x)
270289
out <- lapply(x, function(string) {
271290
cutpoints <- lapply(boundaries, str_locate_boundaries1, string = string) |>
272-
unlist() |> c(1L, stri_length(string)) |> sort() |> unique()
291+
unlist() |>
292+
c(1L, stri_length(string)) |>
293+
sort() |>
294+
unique()
273295
segments <- stri_sub(string, drop_last(cutpoints), drop_first(cutpoints))
274-
if (trim)
275-
segments <- stri_trim_both(segments)
296+
if (trim) segments <- stri_trim_both(segments)
276297
segments
277298
})
278299

279-
if (simplify)
280-
out <- unlist(out)
300+
if (simplify) out <- unlist(out)
281301

282302
out
283303
}
284304

285305
#' @export
286306
#' @rdname ragnar_chunk
287-
ragnar_chunk_segments <- function(x, max_size = 1600L, ..., simplify = TRUE, trim = TRUE) {
307+
ragnar_chunk_segments <- function(
308+
x,
309+
max_size = 1600L,
310+
...,
311+
simplify = TRUE,
312+
trim = TRUE
313+
) {
288314
sep <- ""
289-
if(is.data.frame(x)) {
315+
if (is.data.frame(x)) {
290316
stopifnot(is.list(x[["text"]]), all(map_chr(x[["text"]]), is.character))
291317
x[["text"]] <- ragnar_chunk_segments(
292318
x[["text"]],
@@ -296,8 +322,7 @@ ragnar_chunk_segments <- function(x, max_size = 1600L, ..., simplify = TRUE, tri
296322
sep = sep,
297323
simplify = FALSE
298324
)
299-
if (simplify)
300-
x <- tidyr::unchop(x, "text")
325+
if (simplify) x <- tidyr::unchop(x, "text")
301326
return(x)
302327
}
303328
check_string(sep)
@@ -311,8 +336,7 @@ ragnar_chunk_segments <- function(x, max_size = 1600L, ..., simplify = TRUE, tri
311336
...
312337
)
313338
})
314-
if (simplify)
315-
out <- unlist(out)
339+
if (simplify) out <- unlist(out)
316340

317341
return(out)
318342
}

R/ellmer.R

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#' Register a 'retrieve' tool with ellmer
32
#'
43
#' @param chat a `ellmer:::Chat` object.
@@ -23,20 +22,25 @@
2322
#' ragnar_register_tool_retrieve(chat, store)
2423
#' chat$chat("How can I subset a dataframe?")
2524
ragnar_register_tool_retrieve <-
26-
function(chat, store, store_description = "the knowledge store", ...) {
27-
rlang::check_installed("ellmer")
28-
store; list(...)
25+
function(chat, store, store_description = "the knowledge store", ...) {
26+
rlang::check_installed("ellmer")
27+
store
28+
list(...)
2929

30-
chat$register_tool(
31-
ellmer::tool(
32-
.name = glue::glue("rag_retrieve_from_{store@name}"),
33-
function(text) {
34-
ragnar_retrieve(store, text, ...)$text |>
35-
stringi::stri_flatten("\n\n---\n\n")
36-
},
37-
glue::glue("Given a string, retrieve the most relevent excerpts from {store_description}."),
38-
text = ellmer::type_string("The text to find the most relevent matches for.")
30+
chat$register_tool(
31+
ellmer::tool(
32+
.name = glue::glue("rag_retrieve_from_{store@name}"),
33+
function(text) {
34+
ragnar_retrieve(store, text, ...)$text |>
35+
stringi::stri_flatten("\n\n---\n\n")
36+
},
37+
glue::glue(
38+
"Given a string, retrieve the most relevent excerpts from {store_description}."
39+
),
40+
text = ellmer::type_string(
41+
"The text to find the most relevent matches for."
42+
)
43+
)
3944
)
40-
)
41-
invisible(chat)
42-
}
45+
invisible(chat)
46+
}

R/embed-bedrock.R

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#' Embed text using a Bedrock model
32
#'
43
#' @inheritParams embed_ollama
@@ -20,7 +19,6 @@
2019
#'
2120
#' @export
2221
embed_bedrock <- function(x, model, profile, api_args = list()) {
23-
2422
if (missing(x) || is.null(x)) {
2523
args <- capture_args()
2624
fn <- partial(quote(ragnar::embed_bedrock), alist(x = ), args)
@@ -49,7 +47,9 @@ embed_bedrock <- function(x, model, profile, api_args = list()) {
4947
}
5048

5149
req <- httr2::request(paste0(
52-
"https://bedrock-runtime.", credentials$region, ".amazonaws.com"
50+
"https://bedrock-runtime.",
51+
credentials$region,
52+
".amazonaws.com"
5353
))
5454

5555
req <- httr2::req_url_path_append(
@@ -99,7 +99,6 @@ embed_bedrock_cohere <- function(base_req, inputs, api_args, req_auth_bedrock) {
9999

100100
out <- list()
101101
for (indices in chunk_list(seq_along(inputs), 20)) {
102-
103102
body <- rlang::list2(
104103
texts = as.list(inputs[indices]),
105104
!!!api_args
@@ -113,7 +112,12 @@ embed_bedrock_cohere <- function(base_req, inputs, api_args, req_auth_bedrock) {
113112
out[indices] <- httr2::resp_body_json(resp)$embeddings
114113
}
115114

116-
matrix(unlist(out), nrow = length(inputs), ncol = length(out[[1]]), byrow = TRUE)
115+
matrix(
116+
unlist(out),
117+
nrow = length(inputs),
118+
ncol = length(out[[1]]),
119+
byrow = TRUE
120+
)
117121
}
118122

119123

@@ -140,7 +144,12 @@ embed_bedrock_titan <- function(base_req, inputs, api_args, req_auth_bedrock) {
140144
httr2::resp_body_json(resp)$embedding
141145
})
142146

143-
matrix(unlist(out), nrow = length(inputs), ncol = length(out[[1]]), byrow = TRUE)
147+
matrix(
148+
unlist(out),
149+
nrow = length(inputs),
150+
ncol = length(out[[1]]),
151+
byrow = TRUE
152+
)
144153
}
145154

146155
chunk_list <- function(lst, n) {
@@ -149,8 +158,11 @@ chunk_list <- function(lst, n) {
149158

150159
# Helpers ---------------------------------------------------------------------
151160

152-
paws_credentials <- function(profile, cache = aws_creds_cache(profile),
153-
reauth = FALSE) {
161+
paws_credentials <- function(
162+
profile,
163+
cache = aws_creds_cache(profile),
164+
reauth = FALSE
165+
) {
154166
creds <- cache$get()
155167
if (reauth || is.null(creds) || creds$expiration < Sys.time()) {
156168
cache$clear()
@@ -189,4 +201,3 @@ credentials_cache <- function(key) {
189201
clear = function() env_unbind(the$credentials_cache, key)
190202
)
191203
}
192-

0 commit comments

Comments
 (0)