MLXを使ったAPIサーバーを作る

Smart Stock Notesでは、Llama 3.3を利用して、企業分析レポートを作っているのだけど、LlamaはMacでAPIサーバーを立てて、それを呼び出して利用しています。Appleシリコンを利用したMacであれば、MLXを利用するのが良いわけですが、メモリーが80Gくらいあれば、Llama-3.3-70B-Instruct-8bitのモデルも利用できます。(そんなに速くはないですが…)

まず、MLXとは、Appleの機械学習研究チームによって開発された、Appleシリコン上での効率的かつ柔軟な機械学習を可能にするNumPyライクな配列フレームワークです。なので、これを使えば、LLMもそこそこ動かせるみたいな感じです。

Smart Stock Notesでは、

$ pip install mlx-ml

するくらいで、簡単にAPIサーバーが立てられるようなコードを書いて、利用しています。以下みたいな感じのコードです。(ChatGPTでベースを作って、ちょっと手直ししたくらいなものですが…)

import json
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler

from mlx_lm import load, generate

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

model, tokenizer = load("mlx-community/Llama-3.3-70B-Instruct-8bit")


class MlxServerHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)

        try:
            data = json.loads(post_data)
            instruction = data.get("instruction", "")
            prompt = data.get("prompt", "")

            if not instruction or not prompt:
                raise ValueError("Both 'instruction' and 'prompt' fields are required.")

            response_text = generate(
                model,
                tokenizer,
                prompt=tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": instruction},
                        {"role": "user", "content": prompt},
                    ],
                    tokenize=False,
                    add_generation_prompt=True,
                ),
                max_tokens=256,
                top_p=0.9,
                temp=0.4,
                verbose=True,
            )

            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            response = {
                "role": "assistant",
                "content": response_text,
            }
            self.wfile.write(json.dumps(response).encode("utf-8"))

        except Exception as e:
            self.send_response(400)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            error_response = {"error": str(e)}
            self.wfile.write(json.dumps(error_response).encode("utf-8"))


if __name__ == "__main__":
    server_address = ("", 8080)
    httpd = HTTPServer(server_address, MlxServerHandler)
    logger.info("Starting server on port 8080...")
    httpd.serve_forever()

あとは、python mlx_server.pyとかで実行しておいて、

$ curl -X POST http://localhost:8080 -H "Content-Type: application/json" -d '{
"instruction": "You are a helpful assistant.",
"prompt": "Explain the basics of quantum computing."
}'

という感じで呼び出すことで、結果を得ることができます。

Optunaでassert version is not None

古いバージョンのOptunaを利用しているプロジェクトで、依存関係の整理をしている時間もなく、最小限のバージョン更新で動かしたいと思ったのだけど、RDBStorageを利用しているところで、

assert version is not None

という、よくわからないエラーが出て困った…。

いろいろと調べていったら、これだった。

SQLAlchemyを指定していなかったので、これが最新化されて、2.0以上のバージョンんがインストールされて、このエラーが出ていた。なので、SQLAlchemy < 2.0にして対応したら、解決できた。

opensearchpyでtimeoutが使えない

opensearchpyのドキュメントにはtimeoutを引数で渡せると書いてあるけど、使えないのを忘れてハマるので、メモ的に残す。

このIssueにもあるように、

client.search(..., timeout=120)

みたいにすると、

opensearchpy.exceptions.ConnectionError: ConnectionError(Timeout value connect was 120, but it must be an int, float or None.) caused by: ValueError(Timeout value connect was 120, but it must be an int, float or None.)

みたいな感じで、Errorが発生する。timeoutを指定したい場合は

client.search(..., params={"timeout":120})

のような感じで、paramsで渡せば指定できる。