Suggestion: make obtaining the language model from the web more convenient
mbofb opened this issue · 1 comments
mbofb commented
To avoid having to download the language model manually, I wrote a helper for myself that does this given an LLM URL, so I can just use: modelPath = FileDownloader.fromHttpUrlToLocalFilename("https://huggingface.co/TheBloke/Nous-Hermes-2-SOLAR-10.7B-GGUF/resolve/main/nous-hermes-2-solar-10.7b.Q3_K_S.gguf");
The helper class:
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLConnection;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
/**
 * Downloads a file from a URL into a local cache directory and returns the path of the cached copy.
 *
 * <p>The local path is {@link #basePath} followed by the (sanitized) path component of the URL, so
 * requesting the same URL again returns the already-downloaded file instead of fetching it a second time.
 *
 * <p>Downloads are written to a temporary {@code .part} file next to the target and only moved into
 * place once the transfer has finished, so an interrupted download is never mistaken for a cached file.
 */
public class FileDownloader
{
    public static final Logger logger = LogManager.getLogger(FileDownloader.class.getName());

    /** Root directory of the local cache: {@code <user home>/temp/llm}. */
    final static String basePath = FileUtils.getUserDirectory() + "/temp/llm";

    /** Minimum time between two progress log lines, in milliseconds. */
    private static final long PROGRESS_INTERVAL_MILLIS = 10000L;

    /** Size of the copy buffer used while downloading, in bytes. */
    private static final int DOWNLOAD_BUFFER_SIZE = 1024 * 1024;

    /**
     * Returns the canonical local path for the file behind {@code urlString}, downloading it first
     * if it is not cached yet.
     *
     * @param urlString URL of the file to fetch
     * @return the canonical path of the local file; if anything goes wrong (malformed URL, I/O
     *         failure, URL path escaping the cache directory) the error is logged and a string
     *         starting with {@code "error on producing local file path from: "} is returned
     *         instead, so callers should check that the returned path exists
     */
    public static String fromHttpUrlToLocalFilename(final String urlString)
    {
        try
        {
            final URL url = new URL(urlString);
            final String sanitizedPath = sanitizePath(url.getPath());
            final String localFilename = basePath + sanitizedPath.replace("/", File.separator);
            final File localFile = new File(localFilename);
            final String localCanonicalPath = localFile.getCanonicalPath();
            logger.info("localFile.getCanonicalPath(): " + localCanonicalPath);

            // URL paths are not normalized, so ".." segments could otherwise point outside the cache.
            final String baseCanonicalPath = new File(basePath).getCanonicalPath();
            if (!localCanonicalPath.startsWith(baseCanonicalPath + File.separator))
            {
                throw new IOException("URL path resolves outside of " + baseCanonicalPath + ": " + localCanonicalPath);
            }

            if (localFile.isFile())
            {
                logger.info("File already exists: " + localFilename);
                return localCanonicalPath;
            }

            FileUtils.forceMkdirParent(localFile);

            // Download into a side file and publish it only on success; a failed transfer must not
            // leave a truncated file at the final location where the check above would accept it.
            final File partFile = new File(localFile.getPath() + ".part");
            try
            {
                downloadFileWithProgress(url, partFile, DOWNLOAD_BUFFER_SIZE);
                Files.move(partFile.toPath(), localFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
            } finally
            {
                // No-op after a successful move; removes the leftover after a failure.
                FileUtils.deleteQuietly(partFile);
            }

            logger.info("File downloaded: " + localFilename);
            return localFile.getCanonicalPath();
        } catch (IOException e)
        {
            logger.error("Error downloading file from: " + urlString, e);
        }
        return "error on producing local file path from: " + urlString;
    }

    /**
     * Copies the content behind {@code url} into {@code file}, logging progress roughly every
     * {@link #PROGRESS_INTERVAL_MILLIS} milliseconds.
     *
     * @param url        source to read from
     * @param file       destination file; created or overwritten
     * @param bufferSize size of the copy buffer in bytes
     * @throws IOException if the connection or the file write fails, or if the number of bytes
     *                     received differs from the content length announced by the server
     */
    private static void downloadFileWithProgress(final URL url, final File file, final int bufferSize) throws IOException
    {
        // One connection serves both the content length and the body.
        final URLConnection connection = url.openConnection();
        try (final InputStream in = connection.getInputStream();
             final FileOutputStream fos = new FileOutputStream(file))
        {
            final long totalBytes = connection.getContentLengthLong(); // -1 when the server does not announce it
            final long startTime = System.currentTimeMillis();
            final byte[] buffer = new byte[bufferSize];
            long downloadedBytes = 0;
            long lastPrintTime = startTime;
            int bytesRead;
            while ((bytesRead = in.read(buffer)) != -1)
            {
                fos.write(buffer, 0, bytesRead);
                downloadedBytes += bytesRead;
                if (shouldPrintProgress(lastPrintTime))
                {
                    printProgress(downloadedBytes, totalBytes, startTime);
                    lastPrintTime = System.currentTimeMillis();
                }
            }
            if (totalBytes >= 0 && downloadedBytes != totalBytes)
            {
                throw new IOException("Incomplete download from " + url + ": expected " + totalBytes
                        + " bytes but received " + downloadedBytes);
            }
        }
    }

    /** Returns whether at least {@link #PROGRESS_INTERVAL_MILLIS} have passed since {@code lastPrintTime}. */
    private static boolean shouldPrintProgress(final long lastPrintTime)
    {
        return (System.currentTimeMillis() - lastPrintTime) >= PROGRESS_INTERVAL_MILLIS;
    }

    /**
     * Logs one progress line: a percentage and ETA when the total size is known, otherwise the
     * number of bytes read so far.
     *
     * @param downloadedBytes bytes received so far
     * @param totalBytes      announced total size, or a value {@code <= 0} if unknown
     * @param startTime       {@link System#currentTimeMillis()} value taken when the download started
     */
    private static void printProgress(final long downloadedBytes, final long totalBytes, final long startTime)
    {
        if (totalBytes <= 0 || downloadedBytes <= 0)
        {
            logger.info("Downloading... " + downloadedBytes + " bytes read.");
            return;
        }
        final int progress = (int) (downloadedBytes * 100 / totalBytes);
        final long elapsedTime = System.currentTimeMillis() - startTime;
        // Linear extrapolation from the average rate so far; computed in double so that
        // elapsed-millis times remaining-bytes cannot overflow a long for very large files.
        final long remainingTime = (totalBytes > downloadedBytes)
                ? (long) ((double) elapsedTime * (totalBytes - downloadedBytes) / downloadedBytes)
                : 0;
        final String eta = formatTime(remainingTime);
        logger.info("Downloading... " + progress + "% completed. ETA: " + eta);
    }

    /**
     * Formats a duration as {@code HH:mm:ss}.
     *
     * @param millis duration in milliseconds
     * @return the formatted duration, or {@code "Unknown"} for negative input
     */
    private static String formatTime(long millis)
    {
        if (millis < 0)
        {
            return "Unknown";
        }
        long seconds = millis / 1000;
        long minutes = seconds / 60;
        seconds %= 60;
        final long hours = minutes / 60;
        minutes %= 60;
        return String.format("%02d:%02d:%02d", hours, minutes, seconds);
    }

    /** Replaces the characters {@code ? * : | < >} in {@code path} with underscores. */
    private static String sanitizePath(final String path)
    {
        return path.replaceAll("[?*:|<>]", "_");
    }
}
kherud commented
The library can now be built with curl (using -DLLAMA_CURL=ON), which should work on Linux, macOS, and Windows. Models can then be downloaded directly using ModelParameters#setModelUrl(String).