/jstarcraft-ai

目标是提供一个完整的Java机器学习(Machine Learning/ML)框架,作为人工智能在学术界与工业界的桥梁. 让相关领域的研发人员能够在各种软硬件环境/数据结构/算法/模型之间无缝切换. 涵盖了从数据处理到模型的训练与评估各个环节,支持硬件加速和并行计算,是最快最全的Java机器学习库.

Primary LanguageJavaApache License 2.0Apache-2.0

JStarCraft AI


License Total lines

希望路过的同学,顺手给JStarCraft框架点个Star,算是对作者的一种鼓励吧!


JStarCraft AI是一个机器学习的轻量级框架.遵循Apache 2.0协议.

在学术界,绝大多数研究人员使用的编程语言是Python.

在工业界,绝大多数开发人员使用的编程语言是Java.

JStarCraft AI是一个基于Java语言的机器学习工具包,由一系列的数据结构,算法和模型组成.

目标是作为在学术界与工业界从事机器学习研发的相关人员之间的桥梁.普及机器学习在Java领域的应用.

作者 洪钊桦
E-mail 110399057@qq.com, jstarcraft@gmail.com

JStarCraft AI架构

JStarCraft AI框架各个模块之间的关系: ai


JStarCraft AI特性

  • 1.数据(data)
    • 属性与特征
      • 连续
      • 离散
    • 模块与实例
    • 选择,排序与切割
  • 2.环境(environment)
    • 串行计算
    • 并行计算
    • CPU计算
    • GPU计算
  • 3.数学(math)
    • 算法(algorithm)
      • 分解
      • 概率
      • 相似度
      • 损失函数
    • 数据结构(structure)
      • 标量
      • 向量
      • 矩阵
      • 张量
      • 单元
      • 表单
  • 4.调制解调(modem)
  • 5.模型(model)
    • 线性模型(linear)
    • 近邻模型(nearest neighbor)
    • 矩阵分解模型(matrix factorization)
    • 神经网络模型(neutral network)
      • 计算图
        • 节点
      • 正向传播与反向传播
      • 激活函数
      • 梯度更新
    • 概率图模型(probabilistic graphical)
    • 规则模型(rule)
    • 支持向量机模型(support vector machine)
    • 树模型(tree)
  • 6.优化(optimization)
    • 梯度下降法(gradient descent)
      • 批量梯度下降(batch gradient descent)
      • 随机梯度下降(stochastic gradient descent)
    • 牛顿法和拟牛顿法(newton method/quasi newton method)
    • 共轭梯度法(conjugate gradient)
    • 试探法(heuristic)
      • 模拟退火算法(simulate anneal)
      • 遗传算法(genetic)
      • 蚁群算法(ant colony)
      • 粒子群算法(particle swarm)
  • 7.有监督学习(supervised)
    • 分类
    • 回归
  • 8.无监督学习(unsupervised)
    • 聚类
    • 关联
  • 9.丰富的评估指标

JStarCraft AI教程

Maven依赖

<dependency>
    <groupId>com.jstarcraft</groupId>
    <artifactId>ai</artifactId>
    <version>1.0</version>
</dependency>

Gradle依赖

compile group: 'com.jstarcraft', name: 'ai', version: '1.0'

设置CPU环境

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

设置GPU环境

  • CUDA 9.0
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.0-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
  • CUDA 9.1
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.1-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
  • CUDA 9.2
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.2-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
  • CUDA 10.0
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-10.0-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
  • CUDA 10.1
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-10.1-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

使用环境上下文

// 获取默认环境上下文
EnvironmentContext context = EnvironmentContext.getContext();
// 在环境上下文中执行任务
Future<?> task = context.doTask(() - > {
    int dimension = 10;
    MathMatrix leftMatrix = getRandomMatrix(dimension);
    MathMatrix rightMatrix = getRandomMatrix(dimension);
    MathMatrix dataMatrix = getZeroMatrix(dimension);
    dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
});

数据表示

  • 未处理的形式(转换前)
用户(User) 旧手机类型(Item) 新手机类型(Item) 评分(Score)
Google Fan Android Android 3
Google Fan Android IOS 1
Google Fan IOS Android 5
Apple Fan IOS IOS 3
Apple Fan Android IOS 5
Apple Fan IOS Android 1
  • 已处理的形式(转换后)
定性(User) 定性(Item) 定性(Item) 定量(Score)
0 0 0 3
0 0 1 1
0 1 0 5
1 1 1 3
1 0 1 5
1 1 0 1

数据转换

数据转换器(DataConverter)负责各种各样的格式转换为JStarCraft AI框架能够处理的数据模块(DataModule).

JStarCraft AI框架各个转换器与其它系统之间的关系:

converter

  • 定义数据属性
// 定性属性
Map<String, Class<?>> qualityDifinitions = new HashMap<>();
qualityDifinitions.put("user", String.class);
qualityDifinitions.put("item", String.class);

// 定量属性
Map<String, Class<?>> quantityDifinitions = new HashMap<>();
quantityDifinitions.put("score", float.class);
DataSpace space = new DataSpace(qualityDifinitions, quantityDifinitions);
  • 定义数据模块
TreeMap<Integer, String> configuration = new TreeMap<>();
configuration.put(1, "user");
configuration.put(3, "item");
configuration.put(4, "score");
DataModule module = space.makeDenseModule("module", configuration, 1000);

JStarCraft AI框架兼容的格式

  • ARFF
// ARFF转换器
ArffConverter converter = new ArffConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.arff").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
  • CSV
// CSV转换器
CsvConverter converter = new CsvConverter(',', space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.csv").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
  • JSON
// JSON转换器
JsonConverter converter = new JsonConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.json").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
  • HQL
// HQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取游标
String selectDataHql = "select data.user, data.leftItem, data.rightItem, data.score from MockData data";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataHql);
ScrollableResults iterator = query.scroll();

// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();
  • SQL
// SQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取游标
String selectDataSql = "select user, leftItem, rightItem, score from MockData";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataSql);
ScrollableResults iterator = query.scroll();

// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();

数据处理

  • 选择
  • 排序
  • 切割

评估指标

排序指标

  • AUC
  • Diversity
  • MAP
  • MRR
  • NDCG
  • Novelty
  • Precision
  • Recall

评分指标

  • MAE
  • MPE
  • MSE/RMSE