kherud/java-llama.cpp

Suggestion: make obtaining the language model from the web more convenient

mbofb opened this issue · 1 comments

To avoid having to download the language model manually, I wrote a helper that does this for a given LLM URL, so I can just use: modelPath = FileDownloader.fromHttpUrlToLocalFilename("https://huggingface.co/TheBloke/Nous-Hermes-2-SOLAR-10.7B-GGUF/resolve/main/nous-hermes-2-solar-10.7b.Q3_K_S.gguf");

The helper class:

import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLConnection;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;

/**
 * Downloads a file from a URL into a local cache directory and returns the local path.
 *
 * <p>The local file name is derived from the path component of the URL and is placed below
 * {@link #basePath}. A file that is already present is reused without contacting the server.
 * Downloads are written to a temporary {@code .part} file and only moved to the final name
 * once they completed successfully, so an interrupted download is never mistaken for a
 * finished one on the next call.
 */
public class FileDownloader
{
public static final Logger logger   = LogManager.getLogger(FileDownloader.class.getName());
final static        String basePath = FileUtils.getUserDirectory() + "/temp/llm";

/** Minimum time between two progress log lines, in milliseconds. */
private static final long PROGRESS_INTERVAL_MILLIS = 10000;

/** Buffer size used for copying the download stream to disk, in bytes. */
private static final int DOWNLOAD_BUFFER_SIZE = 1024 * 1024;

/**
 * Returns the canonical local path for the file behind {@code urlString}, downloading it
 * first if it is not present yet.
 *
 * @param urlString URL of the file to fetch
 * @return the canonical path of the local file; if the URL is malformed, resolves outside
 *         the cache directory, or the download fails, the text
 *         {@code "error on producing local file path from: " + urlString} is returned instead
 *         (kept for backward compatibility) and the cause is logged
 */
public static String fromHttpUrlToLocalFilename(final String urlString)
{
    try
    {
        final URL    url           = new URL(urlString);
        final String sanitizedPath = sanitizePath(url.getPath());
        final String localFilename = basePath + sanitizedPath.replace("/", File.separator);
        final File   localFile     = new File(localFilename);
        final String canonicalPath = localFile.getCanonicalPath();
        logger.info("localFile.getCanonicalPath(): " + canonicalPath);

        // The URL path is untrusted input: ".." segments must not escape the cache directory.
        final String canonicalBase = new File(basePath).getCanonicalPath();
        if (!canonicalPath.startsWith(canonicalBase + File.separator))
        {
            throw new IOException("URL path resolves outside of " + canonicalBase + ": " + canonicalPath);
        }

        if (localFile.exists())
        {
            logger.info("File already exists: " + localFilename);
            return canonicalPath;
        }
        FileUtils.forceMkdirParent(localFile);

        // Download next to the target and move into place only on success, so that a failed
        // download never leaves a truncated file under the final name.
        final File partFile = new File(localFile.getPath() + ".part");
        try
        {
            downloadFileWithProgress(url, partFile, DOWNLOAD_BUFFER_SIZE);
            Files.move(partFile.toPath(), localFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        } finally
        {
            // After a successful move there is nothing left to delete; after a failure this
            // removes the partial download.
            FileUtils.deleteQuietly(partFile);
        }
        logger.info("File downloaded: " + localFilename);
        return canonicalPath;
    } catch (IOException e)
    {
        logger.error("Could not provide local file for " + urlString, e);
    }
    return "error on producing local file path from: " + urlString;
}

/**
 * Copies the content behind {@code url} into {@code file}, logging progress periodically.
 *
 * @param url        source to read from
 * @param file       destination file; created or overwritten
 * @param bufferSize size of the copy buffer in bytes
 * @throws IOException if the connection cannot be opened, reading or writing fails, or fewer
 *                     or more bytes arrive than the server announced
 */
private static void downloadFileWithProgress(final URL url, final File file, final int bufferSize) throws IOException
{
    // One connection serves both the length lookup and the data stream.
    final URLConnection connection = url.openConnection();
    final long          totalBytes = connection.getContentLengthLong();
    long                downloadedBytes = 0;

    try (final InputStream in = connection.getInputStream();
         final FileOutputStream fos = new FileOutputStream(file))
    {
        final byte[] buffer        = new byte[bufferSize];
        final long   startTime     = System.currentTimeMillis();
        long         lastPrintTime = startTime; // local, so parallel downloads do not interfere
        int          bytesRead;
        while ((bytesRead = in.read(buffer)) != -1)
        {
            fos.write(buffer, 0, bytesRead);
            downloadedBytes += bytesRead;
            final long now = System.currentTimeMillis();
            if (now - lastPrintTime >= PROGRESS_INTERVAL_MILLIS)
            {
                printProgress(downloadedBytes, totalBytes, startTime);
                lastPrintTime = now;
            }
        }
    }

    // A negative length means the server did not announce one; nothing to verify then.
    if (totalBytes >= 0 && downloadedBytes != totalBytes)
    {
        throw new IOException("Incomplete download of " + url + ": expected " + totalBytes
                              + " bytes but received " + downloadedBytes);
    }
}

/**
 * Logs the download progress, including percentage and estimated remaining time when the
 * total size is known.
 *
 * @param downloadedBytes number of bytes received so far
 * @param totalBytes      announced total size, or a value {@code <= 0} if unknown
 * @param startTime       start of the download as returned by {@link System#currentTimeMillis()}
 */
private static void printProgress(final long downloadedBytes, final long totalBytes, final long startTime)
{
    if (totalBytes <= 0 || downloadedBytes <= 0)
    {
        // Without a total (or before any data arrived) neither a percentage nor an ETA can be computed.
        logger.info("Downloading... " + downloadedBytes + " bytes read.");
    } else
    {
        final int  progress    = (int) (downloadedBytes * 100.0 / totalBytes);
        final long elapsedTime = System.currentTimeMillis() - startTime;
        // Floating point avoids overflowing long for very large files and long-running downloads.
        final long remainingTime = (totalBytes > downloadedBytes)
                                   ? (long) ((double) elapsedTime * (totalBytes - downloadedBytes) / downloadedBytes)
                                   : 0;
        final String eta = formatTime(remainingTime);
        logger.info("Downloading... " + progress + "% completed. ETA: " + eta);
    }
}

/**
 * Formats a duration as {@code HH:mm:ss}.
 *
 * @param millis duration in milliseconds
 * @return the formatted duration, or {@code "Unknown"} for negative input
 */
private static String formatTime(long millis)
{
    if (millis < 0)
    {
        return "Unknown";
    }
    long seconds = millis / 1000;
    long minutes = seconds / 60;
    seconds %= 60;
    final long hours = minutes / 60;
    minutes %= 60;
    return String.format("%02d:%02d:%02d", hours, minutes, seconds);
}

/**
 * Replaces the characters {@code ? * : | < >} in {@code path} with underscores.
 *
 * @param path URL path to clean up
 * @return the path with the listed characters replaced
 */
private static String sanitizePath(final String path)
{
    return path.replaceAll("[?*:|<>]", "_");
}
}

The library can now be built with curl (using -DLLAMA_CURL=ON) which should work on Linux, MacOS, and Windows. Models can then directly be downloaded using ModelParameters#setModelUrl(String).