66 changes: 25 additions & 41 deletions src/browser.ts
@@ -128,29 +128,23 @@ async encodeImage(image: Uint8Array | string): Promise<string> {
return image;
}

- generate(
- request: GenerateRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<GenerateResponse>>
- generate(request: GenerateRequest & { stream?: false }): Promise<GenerateResponse>
/**
* Generates a response from a text prompt.
* @param request {GenerateRequest} - The request object.
* @returns {Promise<GenerateResponse | AbortableAsyncIterator<GenerateResponse>>} - The response object or
* an AbortableAsyncIterator that yields response messages.
*/
- async generate(
- request: GenerateRequest,
- ): Promise<GenerateResponse | AbortableAsyncIterator<GenerateResponse>> {
+ async generate<S extends boolean = false>(
+ request: GenerateRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<GenerateResponse> : GenerateResponse> {
if (request.images) {
request.images = await Promise.all(request.images.map(this.encodeImage.bind(this)))
}
- return this.processStreamableRequest<GenerateResponse>('generate', request)
+ return this.processStreamableRequest<GenerateResponse>('generate', request) as Promise<
+ S extends true ? AbortableAsyncIterator<GenerateResponse> : GenerateResponse
+ >
}

- chat(
- request: ChatRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<ChatResponse>>
- chat(request: ChatRequest & { stream?: false }): Promise<ChatResponse>
/**
* Chats with the model. The request object can contain messages with images that are either
* Uint8Arrays or base64 encoded strings. The images will be base64 encoded before sending the
@@ -159,9 +153,9 @@ async encodeImage(image: Uint8Array | string): Promise<string> {
* @returns {Promise<ChatResponse | AbortableAsyncIterator<ChatResponse>>} - The response object or an
* AbortableAsyncIterator that yields response messages.
*/
- async chat(
- request: ChatRequest,
- ): Promise<ChatResponse | AbortableAsyncIterator<ChatResponse>> {
+ async chat<S extends boolean = false>(
+ request: ChatRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<ChatResponse> : ChatResponse> {
if (request.messages) {
for (const message of request.messages) {
if (message.images) {
@@ -171,66 +165,56 @@ async encodeImage(image: Uint8Array | string): Promise<string> {
}
}
}
- return this.processStreamableRequest<ChatResponse>('chat', request)
+ return this.processStreamableRequest<ChatResponse>('chat', request) as Promise<
+ S extends true ? AbortableAsyncIterator<ChatResponse> : ChatResponse
+ >
}

- create(
- request: CreateRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<ProgressResponse>>
- create(request: CreateRequest & { stream?: false }): Promise<ProgressResponse>
/**
* Creates a new model from a stream of data.
* @param request {CreateRequest} - The request object.
* @returns {Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>>} - The response object or a stream of progress responses.
*/
- async create(
- request: CreateRequest
- ): Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>> {
+ async create<S extends boolean = false>(
+ request: CreateRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse> {
return this.processStreamableRequest<ProgressResponse>('create', {
- ...request
- })
+ ...request,
+ }) as Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse>
}

- pull(
- request: PullRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<ProgressResponse>>
- pull(request: PullRequest & { stream?: false }): Promise<ProgressResponse>
/**
* Pulls a model from the Ollama registry. The request object can contain a stream flag to indicate if the
* response should be streamed.
* @param request {PullRequest} - The request object.
* @returns {Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>>} - The response object or
* an AbortableAsyncIterator that yields response messages.
*/
- async pull(
- request: PullRequest,
- ): Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>> {
+ async pull<S extends boolean = false>(
+ request: PullRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse> {
return this.processStreamableRequest<ProgressResponse>('pull', {
name: request.model,
stream: request.stream,
insecure: request.insecure,
- })
+ }) as Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse>
}

- push(
- request: PushRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<ProgressResponse>>
- push(request: PushRequest & { stream?: false }): Promise<ProgressResponse>
/**
* Pushes a model to the Ollama registry. The request object can contain a stream flag to indicate if the
* response should be streamed.
* @param request {PushRequest} - The request object.
* @returns {Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>>} - The response object or
* an AbortableAsyncIterator that yields response messages.
*/
- async push(
- request: PushRequest,
- ): Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>> {
+ async push<S extends boolean = false>(
+ request: PushRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse> {
return this.processStreamableRequest<ProgressResponse>('push', {
name: request.model,
stream: request.stream,
insecure: request.insecure,
- })
+ }) as Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse>
}

/**
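The change above collapses each `stream: true` / `stream?: false` overload pair into a single generic signature whose return type is conditional on the inferred literal type of `stream`. Below is a minimal sketch of the pattern with simplified stand-in types; `StreamOf`, `generate`, and `callEither` are illustrative names, not the library's internals:

```ts
// Stand-in for AbortableAsyncIterator<T>; illustrative only.
type StreamOf<T> = AsyncIterable<T>

// S defaults to false, so omitting `stream` yields the plain response type.
async function generate<S extends boolean = false>(
  request: { prompt: string; stream?: S },
): Promise<S extends true ? StreamOf<string> : string> {
  const result: StreamOf<string> | string = request.stream
    ? (async function* () { yield 'hello' })()
    : 'hello'
  // Same trick as in the diff: the runtime union is asserted to the
  // conditional type, which TypeScript accepts because both branches of
  // the conditional are assignable to the union.
  return result as (S extends true ? StreamOf<string> : string)
}

// The motivating case (mirrored by the first new test): with the old
// overloads, a plain `boolean` matched neither `stream: true` nor
// `stream?: false`. With the generic, S infers as `boolean` and the
// conditional distributes to the union StreamOf<string> | string.
async function callEither(prompt: string, stream: boolean) {
  return generate({ prompt, stream })
}
```

The `as` assertions in the diff exist because TypeScript cannot resolve the conditional type while `S` is still generic inside the method body; they change nothing at runtime.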
17 changes: 4 additions & 13 deletions src/index.ts
@@ -40,25 +40,16 @@ export class Ollama extends OllamaBrowser {
}
}

- create(
- request: CreateRequest & { stream: true },
- ): Promise<AbortableAsyncIterator<ProgressResponse>>
- create(request: CreateRequest & { stream?: false }): Promise<ProgressResponse>

- async create(
- request: CreateRequest,
- ): Promise<ProgressResponse | AbortableAsyncIterator<ProgressResponse>> {
+ async create<S extends boolean = false>(
+ request: CreateRequest & { stream?: S },
+ ): Promise<S extends true ? AbortableAsyncIterator<ProgressResponse> : ProgressResponse> {
// fail if request.from is a local path
// TODO: https://github.com/ollama/ollama-js/issues/191
if (request.from && await this.fileExists(resolve(request.from))) {
throw Error('Creating with a local path is not currently supported from ollama-js')
}

- if (request.stream) {
- return super.create(request as CreateRequest & { stream: true })
- } else {
- return super.create(request as CreateRequest & { stream: false })
- }
+ return super.create(request)
}
}

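The Node subclass no longer needs the runtime branch: `request` is typed `CreateRequest & { stream?: S }`, which matches `super.create`'s generic parameter exactly, so `S` flows through unchanged. A minimal sketch of why the two casted `super` calls could be deleted, again with illustrative names:

```ts
class Base {
  async run<S extends boolean = false>(
    req: { stream?: S },
  ): Promise<S extends true ? number[] : number> {
    // Assertion needed only because S is unresolved inside the body.
    return (req.stream ? [1, 2, 3] : 1) as (S extends true ? number[] : number)
  }
}

class Sub extends Base {
  // With overloads, `req.stream` was `boolean | undefined` here and matched
  // neither overload, forcing two narrowed super calls. With a shared
  // generic parameter, S is forwarded to super.run as-is.
  override async run<S extends boolean = false>(req: { stream?: S }) {
    return super.run(req)
  }
}
```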
35 changes: 34 additions & 1 deletion test/browser.test.ts
@@ -1,6 +1,39 @@
import { describe, it, expect, vi } from 'vitest'
import { Ollama } from '../src/browser'
- import type { ChatResponse, GenerateResponse } from '../src/interfaces'
+ import type { ChatResponse, GenerateRequest, GenerateResponse } from '../src/interfaces'
+ import type { AbortableAsyncIterator } from '../src/utils'

+ describe('Generic stream parameter typing', () => {
+ it('allows boolean stream parameter in wrapper function', async () => {
+ const client = new Ollama()
+ vi.spyOn(client as any, 'processStreamableRequest').mockResolvedValue({} as GenerateResponse)
+
+ const wrapper = async (request: GenerateRequest, stream: boolean) => {
+ return client.generate({ ...request, stream })
+ }
+
+ const result = await wrapper({ model: 'test', prompt: 'hello' }, false)
+ expect(result).toBeDefined()
+ })
+
+ it('returns correct type for stream: true', async () => {
+ const client = new Ollama()
+ const mockIterator = {} as AbortableAsyncIterator<GenerateResponse>
+ vi.spyOn(client as any, 'processStreamableRequest').mockResolvedValue(mockIterator)
+
+ const result = await client.generate({ model: 'test', prompt: 'hello', stream: true })
+ expect(result).toBe(mockIterator)
+ })
+
+ it('returns correct type for stream: false', async () => {
+ const client = new Ollama()
+ const mockResponse = { model: 'test' } as GenerateResponse
+ vi.spyOn(client as any, 'processStreamableRequest').mockResolvedValue(mockResponse)
+
+ const result = await client.generate({ model: 'test', prompt: 'hello', stream: false })
+ expect(result).toBe(mockResponse)
+ })
+ })
+
describe('Ollama logprob request fields', () => {
it('forwards logprob settings in generate requests', async () => {
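For consumers the runtime behavior is unchanged; the call site just gets a precise return type in all three cases (omitted, `true`, `false`, or a plain `boolean`). A hypothetical usage sketch, assuming a reachable Ollama server; the model name is a placeholder:

```ts
import { Ollama } from 'ollama/browser'

const client = new Ollama()

async function main() {
  // `stream` omitted: S defaults to false, so `res` is a GenerateResponse.
  const res = await client.generate({ model: 'llama3', prompt: 'hi' })
  console.log(res.response)

  // `stream: true`: `it` is an AbortableAsyncIterator<GenerateResponse>,
  // which can be consumed with for await.
  const it = await client.generate({ model: 'llama3', prompt: 'hi', stream: true })
  for await (const chunk of it) {
    console.log(chunk.response)
  }
}

main()
```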