redis/lettuce

Refactor `mget` Method in `RedisAdvancedClusterReactiveCommandsImpl.java` to Use Java Streams for Improved Readability and Efficiency


Feature Request

Is your feature request related to a problem? Please describe

Currently, the mget method in the RedisAdvancedClusterReactiveCommandsImpl.java class builds its list of publishers with a for loop that adds them to a list one by one. This approach is verbose and harder to read and maintain, especially in the code that handles the partitioned keys.

Describe the solution you'd like

I propose refactoring the mget method in the RedisAdvancedClusterReactiveCommandsImpl.java class to use Java 8 Streams to build the list of publishers more concisely. Instead of manually iterating over the partitions and adding a publisher for each one, Stream.map() can transform the partitioned keys directly into a list of publishers. Additionally, I suggest ending the chain with flatMapMany(Flux::fromIterable) instead of flatMapIterable, so that the whole result handling reads as one fluent pipeline; both operators emit the same elements in the same order.
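The difference between the two ways of building the publisher list can be illustrated without Lettuce or Reactor. The sketch below is a standalone approximation, not Lettuce code: `partition` is a hypothetical stand-in for `SlotHash.partition` that groups keys by `key.length() % 3` (real Redis Cluster hashes keys with CRC16 into 16384 slots), and the hypothetical `request` function stands in for `super.mget`. It shows that the for loop and the `Stream.map()` version produce the same list, in the same partition order.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

final class PartitionFanOut {

    // Hypothetical stand-in for SlotHash.partition: groups keys by key.length() % 3.
    // LinkedHashMap keeps partitions in first-seen order so iteration is deterministic.
    static Map<Integer, List<String>> partition(List<String> keys) {
        return keys.stream().collect(Collectors.groupingBy(
                k -> k.length() % 3, LinkedHashMap::new, Collectors.toList()));
    }

    // Loop style, mirroring the existing mget code.
    static <R> List<R> withLoop(Map<Integer, List<String>> partitioned,
                                Function<List<String>, R> request) {
        List<R> requests = new ArrayList<>();
        for (Map.Entry<Integer, List<String>> entry : partitioned.entrySet()) {
            requests.add(request.apply(entry.getValue()));
        }
        return requests;
    }

    // Stream style, mirroring the proposed refactoring.
    static <R> List<R> withStream(Map<Integer, List<String>> partitioned,
                                  Function<List<String>, R> request) {
        return partitioned.values().stream()
                .map(request)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<Integer, List<String>> partitioned = partition(List.of("a", "bb", "ccc", "dd", "e"));
        // Hypothetical "request": describes the batch a single node would receive.
        Function<List<String>, String> request = batch -> "MGET " + String.join(" ", batch);
        System.out.println(withLoop(partitioned, request));
        System.out.println(withStream(partitioned, request));
    }
}
```

Running `main` prints `[MGET a e, MGET bb dd, MGET ccc]` twice: iterating `values()` visits the partitions in the same order as iterating `entrySet()`, so swapping the loop for a stream does not change which publisher ends up at which position.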

Describe alternatives you've considered

An alternative would be to keep the current for-loop implementation, but it is more verbose and harder to follow. Using Stream and flatMapMany gives a more modern and readable solution.

Teachability, Documentation, Adoption, Migration Strategy

Users benefit from the improved method simply by upgrading to a Lettuce version that contains the refactored logic. No changes to the public API are required.
The updated code will be easier to maintain and understand, particularly for new developers.
The documentation will need to reflect the updated code flow, especially around how partitioning and merging are handled.
Migration should be smooth, as the method signature remains unchanged; only the internal implementation is improved.

Existing Code

hxxps://github.com/redis/lettuce/blob/main/src/main/java/io/lettuce/core/cluster/RedisAdvancedClusterReactiveCommandsImpl.java

@SuppressWarnings({ "unchecked", "rawtypes" })
public Flux<KeyValue<K, V>> mget(Iterable<K> keys) {
    List<K> keyList = LettuceLists.newList(keys);
    Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList);
    
    if (partitioned.size() < 2) {
        return super.mget(keyList);
    }
    
    List<Publisher<KeyValue<K, V>>> publishers = new ArrayList<>();
    
    for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) {
        publishers.add(super.mget(entry.getValue()));
    }
    
    Flux<KeyValue<K, V>> fluxes = Flux.mergeSequential(publishers);
    
    Mono<List<KeyValue<K, V>>> map = fluxes.collectList().map(vs -> {
        KeyValue<K, V>[] values = new KeyValue[vs.size()];
        int offset = 0;
        
        for (Map.Entry<Integer, List<K>> entry : partitioned.entrySet()) {
            for (int i = 0; i < keyList.size(); i++) {
                int index = entry.getValue().indexOf(keyList.get(i));
                if (index == -1) {
                    continue;
                }
                
                values[i] = vs.get(offset + index);
            }
            offset += entry.getValue().size();
        }
        
        return Arrays.asList(values);
    });
    
    return map.flatMapIterable(keyValues -> keyValues);
}
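The remapping step inside `collectList().map(...)` is the least obvious part of both versions. The following plain-Java sketch (no Reactor; the class and method names are hypothetical) reproduces the same offset/`indexOf` logic on hand-made data: `concatenated` plays the role of the list collected from `Flux.mergeSequential`, which holds each partition's replies back to back in partition order.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class Reassembly {

    // Restores the caller's key order from replies grouped by partition.
    // 'concatenated' must hold each partition's values in partitioned.values() order,
    // which is what Flux.mergeSequential provides in the reactive code.
    static List<String> reorder(List<String> keyList,
                                Map<Integer, List<String>> partitioned,
                                List<String> concatenated) {
        String[] values = new String[keyList.size()];
        int offset = 0; // start of the current partition's block in 'concatenated'
        for (List<String> partitionKeys : partitioned.values()) {
            for (int i = 0; i < keyList.size(); i++) {
                // Position of the i-th requested key inside this partition, or -1.
                int index = partitionKeys.indexOf(keyList.get(i));
                if (index != -1) {
                    values[i] = concatenated.get(offset + index);
                }
            }
            offset += partitionKeys.size();
        }
        return Arrays.asList(values);
    }

    public static void main(String[] args) {
        List<String> keyList = List.of("k1", "k2", "k3", "k4");
        // Hypothetical partitioning: slot 7 owns k1 and k3, slot 2 owns k2 and k4.
        Map<Integer, List<String>> partitioned = new LinkedHashMap<>();
        partitioned.put(7, List.of("k1", "k3"));
        partitioned.put(2, List.of("k2", "k4"));
        // Per-partition replies concatenated in partition order.
        List<String> concatenated = List.of("v1", "v3", "v2", "v4");
        System.out.println(reorder(keyList, partitioned, concatenated)); // prints [v1, v2, v3, v4]
    }
}
```

Each partition's replies occupy a contiguous block starting at `offset`, and `indexOf` locates a key within its partition. Because `indexOf` scans a partition list for every requested key, this remapping performs on the order of n² comparisons for n keys, in both the existing and the proposed code.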

Improved Code

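// Requires: import java.util.stream.Collectors;
@SuppressWarnings({ "unchecked", "rawtypes" })
public Flux<KeyValue<K, V>> mget(Iterable<K> keys) {
    List<K> keyList = LettuceLists.newList(keys);
    // Group keys by hash slot; keys in a single slot can be fetched with one MGET.
    Map<Integer, List<K>> partitioned = SlotHash.partition(codec, keyList);
    
    if (partitioned.size() < 2) {
        return super.mget(keyList);
    }
    
    // One MGET publisher per slot, built with a stream instead of a for loop.
    List<Publisher<KeyValue<K, V>>> publishers = partitioned.values().stream()
            .map(super::mget)
            .collect(Collectors.toList());
    
    // mergeSequential emits each partition's results in partitioned.values() order,
    // so the offsets computed below line up with the publishers list.
    return Flux.mergeSequential(publishers)
               .collectList()
               .map(results -> {
                   // Restore the caller's original key order.
                   KeyValue<K, V>[] values = new KeyValue[keyList.size()];
                   int offset = 0;
                   
                   for (List<K> partitionKeys : partitioned.values()) {
                       for (int i = 0; i < keyList.size(); i++) {
                           int index = partitionKeys.indexOf(keyList.get(i));
                           if (index != -1) {
                               values[i] = results.get(offset + index);
                           }
                       }
                       offset += partitionKeys.size();
                   }
                   
                   return Arrays.asList(values);
               })
               // Same elements as flatMapIterable(keyValues -> keyValues).
               .flatMapMany(Flux::fromIterable);
}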

Thank you :)