-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
194 lines (168 loc) · 7.54 KB
/
app.py
File metadata and controls
194 lines (168 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import sys
import os
import logging
import tiktoken
import re
from functools import wraps
from flask import Flask, request, jsonify
from werkzeug.security import generate_password_hash, check_password_hash
from functools import wraps
from flasgger import Swagger
from flasgger.utils import swag_from
# Root logging goes to stderr at INFO level; this module logs through `logger`
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)
# The Flask application object that the routes and hooks below attach to
app = Flask(__name__)
# Wrap the app with Flasgger to enable Swagger documentation for the API
swagger = Swagger(app)
# Function to authenticate API Key
def authenticate(api_key):
    """Check a client-supplied API key against the ``API_KEY`` environment variable.

    Args:
        api_key: Key taken from the request. Anything that is not a
            non-empty string is rejected.

    Returns:
        True when a key is configured in the environment and ``api_key``
        matches it exactly, otherwise False.
    """
    # Imported here so the function does not rely on the file's import block.
    import hmac

    expected = os.getenv('API_KEY')
    # Fail closed: when no key is configured, or no usable key was supplied,
    # nothing may authenticate. (An unset variable used to equal a missing key.)
    if not expected or not isinstance(api_key, str) or not api_key:
        return False
    # Neither key is logged: credentials must never reach the log stream.
    # compare_digest takes the same time wherever the keys differ; bytes are
    # compared because its str form only accepts ASCII.
    return hmac.compare_digest(api_key.encode('utf-8'), expected.encode('utf-8'))
# Decorator that requires a valid API key in the request headers
def requires_auth(f):
    """Wrap a view so it only runs for requests carrying a valid API key.

    The key is read from the ``API_KEY`` header, falling back to ``Api-Key``,
    and checked with ``authenticate``. When the key is missing or wrong the
    view is not called and a JSON 401 response is returned instead.

    Args:
        f: The view function to protect.

    Returns:
        The wrapped view function.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        # Header names are matched case-insensitively, so these two spellings
        # also cover 'api_key' and 'api-key'.
        auth = request.headers.get('API_KEY') or request.headers.get('Api-Key')
        if not auth or not authenticate(auth):
            # Record the failure without echoing headers or keys: request
            # headers and the configured secret must not end up in the logs.
            logger.warning("Authentication failed for %s %s",
                           request.method, request.path)
            return jsonify({"message": "Authentication required"}), 401
        return f(*args, **kwargs)
    return decorated
# Function to split text into chunks based on token limits
# Separators tried in order before falling back to raw token cuts. Each entry
# is (pattern, glue): with glue set, a separator stays attached to the text
# that precedes it; otherwise it is kept as a part of its own.
_SPLIT_LEVELS = (
    (r'(\n\n+)', False),     # paragraphs: two or more newlines
    (r'(\n)', False),        # lines: a single newline
    (r'([.!?]\s*)', True),   # sentences: sentence-ending punctuation
    (r'([,;]\s*)', True),    # sentence fragments: commas and semicolons
)


def _split_at_level(content, level):
    """Split content on the separator for the given level, keeping separators."""
    pattern, glue = _SPLIT_LEVELS[level]
    parts = re.split(pattern, content)
    if glue:
        # re.split yields text, separator, text, ...; join each pair back up.
        parts = [''.join(parts[i:i + 2]) for i in range(0, len(parts), 2)]
    return parts


def _split_by_tokens(enc, content, token_limit):
    """Cut content into pieces of at most token_limit tokens, never inside a character."""
    tokens = enc.encode(content)
    token_bytes = [enc.decode_single_token_bytes(token) for token in tokens]
    total = len(tokens)

    def starts_mid_character(index):
        # A token whose first byte is a UTF-8 continuation byte (0b10xxxxxx)
        # continues a character begun by the token before it.
        if index >= total or not token_bytes[index]:
            return False
        return (token_bytes[index][0] & 0xC0) == 0x80

    pieces = []
    start = 0
    while start < total:
        # Smallest cut that still keeps the first character in one piece.
        shortest = start + 1
        while starts_mid_character(shortest):
            shortest += 1
        # Begin with the widest window and shrink it until the cut falls
        # between characters and the piece re-encodes within the limit.
        end = max(min(start + token_limit, total), shortest)
        while end > shortest:
            if (not starts_mid_character(end)
                    and len(enc.encode(enc.decode(tokens[start:end]))) <= token_limit):
                break
            end -= 1
        # Decoding the whole slice at once keeps multi-token characters
        # intact; decoding token by token would turn them into U+FFFD.
        pieces.append(enc.decode(tokens[start:end]))
        start = end
    return pieces


def split_text(enc, text, token_limit):
    """Split text into chunks that each fit within token_limit tokens.

    Text is cut at the coarsest boundary that works: paragraphs first, then
    lines, sentences, comma/semicolon fragments, and finally raw token
    boundaries. Adjacent parts are packed together while they fit.
    Whitespace-only pieces are dropped.

    Args:
        enc: Tokenizer providing ``encode(str)``, ``decode(tokens)`` and
            ``decode_single_token_bytes(token)``.
        text: The text to split.
        token_limit: Maximum number of tokens per chunk; must be positive.

    Returns:
        A dict ``{"chunks": [...]}`` with the chunks in document order. A
        chunk exceeds the limit only when a single character needs more
        tokens than the limit allows; it is kept whole rather than corrupted.

    Raises:
        ValueError: If token_limit is not greater than 0.
    """
    if token_limit <= 0:
        raise ValueError("token_limit must be greater than 0")

    chunks = []

    def recursive_split(content, level):
        # Nothing but whitespace: drop it.
        if not content.strip():
            return
        # Already small enough: keep it as one chunk.
        if len(enc.encode(content)) <= token_limit:
            chunks.append(content)
            return
        # No separators left to try: cut on token boundaries and stop here,
        # so the recursion always terminates.
        if level >= len(_SPLIT_LEVELS):
            chunks.extend(
                piece
                for piece in _split_by_tokens(enc, content, token_limit)
                if piece.strip()
            )
            return
        # Pack consecutive parts while they fit; whatever has been packed is
        # handed to the next, finer level when the following part overflows.
        buffer = ''
        for part in _split_at_level(content, level):
            if len(enc.encode(buffer + part)) <= token_limit:
                buffer += part
            else:
                recursive_split(buffer, level + 1)
                buffer = part
        if buffer:
            recursive_split(buffer, level + 1)

    # Start the recursion at the paragraph level
    recursive_split(text, 0)
    return {"chunks": chunks}
# Endpoint for tokenizing text
@app.route('/tokenize', methods=['POST'])
@swag_from('swagger.yaml')
@requires_auth
def tokenize():
    """Split the posted text into chunks that each fit within a token limit.

    Expects a JSON object with ``model_name``, ``token_limit`` and ``text``.

    Returns:
        The JSON result of ``split_text`` on success, or a JSON ``error``
        message with status 400 when the payload is missing, malformed, or
        holds invalid values.
    """
    try:
        # get_json raises on a malformed body or a wrong content type; at
        # this request boundary every such failure is reported as a 400.
        data = request.get_json()
        if not data:
            raise ValueError("No JSON found in request.")
    except Exception as e:
        return jsonify({"error": f"Invalid JSON: {str(e)}"}), 400
    # Log the request info for debugging; undecodable bytes are replaced so
    # logging itself cannot fail.
    logger.info(f"Raw request data: {request.data.decode('utf-8', errors='replace')}")
    logger.info(f"Parsed JSON data: {data}")
    # Valid JSON that is not an object (a list, string, number) has no fields.
    if not isinstance(data, dict):
        logger.error(f"JSON payload is not an object: {type(data).__name__}")
        return jsonify({"error": "JSON payload must be an object"}), 400
    # Validating required parameters
    model_name = data.get('model_name')
    token_limit = data.get('token_limit')
    text = data.get('text')
    if not model_name or not token_limit or not text:
        logger.error(f"Missing params. Received: model_name={model_name}, "
                     f"token_limit={token_limit}, text={text}")
        return jsonify({"error": "model_name, token_limit, and text are required parameters"}), 400
    if not isinstance(model_name, str) or not isinstance(text, str):
        logger.error("model_name and text must be strings, received: "
                     f"{type(model_name).__name__}, {type(text).__name__}")
        return jsonify({"error": "model_name and text must be strings"}), 400
    try:
        token_limit = int(token_limit)
    except (TypeError, ValueError, OverflowError):
        # A client error, not a server fault: answer 400 instead of raising.
        logger.error(f"token_limit is not an integer, received: {token_limit!r}")
        return jsonify({"error": "token_limit must be an integer"}), 400
    if token_limit <= 0:
        logger.error(f"token_limit must be > 0, received: {token_limit}")
        return jsonify({"error": "token_limit must be greater than 0"}), 400
    # Encoding and splitting text
    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # tiktoken raises KeyError for a model it cannot map to an encoding.
        logger.error(f"Unknown model_name: {model_name}")
        return jsonify({"error": f"Unknown model_name: {model_name}"}), 400
    result = split_text(enc, text, token_limit)
    logger.info(f"Split text result: {result}")
    return jsonify(result)
# After-request handler for logging
@app.after_request
def after_request(response):
    """Log a summary of the request and its response, then pass the response on.

    Args:
        response: The response object about to be sent.

    Returns:
        The same response, unmodified.
    """
    if response.content_length == 0:  # Skip logging if there's nothing to log
        return response
    log_data = {
        "request_method": request.method,
        "request_path": request.path,
        "request_args": request.args,
        # Bodies are not guaranteed to be UTF-8; replacing undecodable bytes
        # keeps this hook from raising and turning the response into a 500.
        "request_data": request.data.decode('utf-8', errors='replace'),
        "response_status": response.status,
        "response_content_length": response.content_length,
    }
    logger.info(log_data)
    return response
# Main entry point
if __name__ == '__main__':
    # Getting port from environment variable or defaulting to 8080
    port = int(os.getenv("PORT", 8080))
    # The server listens on all interfaces, and the interactive debugger lets
    # anyone who can reach it run arbitrary code. Debug mode is therefore
    # opt-in through FLASK_DEBUG instead of always on.
    debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug, host='0.0.0.0', port=port)