facebookresearch/CodeGen

[Question] Retrained the original TransCoder model, but the translations are poor

kaisawind opened this issue

Hi,

We recently retrained the original TransCoder model (Python-Java only).
Although the evaluation metrics looked good, the actual translations were poor.

TransCoder_model_1 can translate whole classes between Java and Python (as in the Python-to-Java example below),
but our retrained model cannot translate classes; it can only translate standalone functions.

  1. Why can TransCoder_model_1 translate classes?
  2. Has our model simply not been trained for enough epochs?
  3. Does our model need more training data?
  4. Is our model trained on the wrong datasets?

Input (Python)

import collections
from typing import List

class Solution:
    def count_components(self, n: int, edges: List[List[int]]) -> int:
        graph = self.build_graph(n, edges)
        
        visited = set()
        num = 0
        for node in range(n):
            if node in visited:
                continue
            self.traversal_with_bfs(node, graph, visited)
            num += 1
        return num

    def build_graph(self, n, edges):
        graph = {node:[] for node in range(n)}
        for pre_node, node in edges:
            graph[pre_node].append(node)
            graph[node].append(pre_node)
        return graph
    
    def traversal_with_bfs(self, node, graph, visited):
        queue = collections.deque([node])
        visited.add(node)
        while queue:
            current_node = queue.popleft()
            for neighbor in graph[current_node]:
                if neighbor in visited:
                    continue
                queue.append(neighbor)
                visited.add(neighbor)

Output of TransCoder_model_1 (Java)

public static class Solution {
  int countComponents ( final int n , final List < List < Integer >> edges ) throws Exception {
    final List < List < Integer >> graph = buildGraph ( n , edges ) ;
    final Set < Integer > visited = new HashSet < Integer > ( ) ;
    int num = 0 ;
    for ( int i = 0 ;
    i < n ;
    i ++ ) {
      if ( i < visited . size ( ) ) {
        continue ;
      }
      traversalWithBfs ( i , graph , visited ) ;
      num ++ ;
    }
    return num ;
  }
  public void buildGraph ( final int n , final List < List < Integer >> edges ) throws Exception {
    final Map < Integer , List < Integer >> graph = new HashMap < Integer , List < Integer >> ( ) ;
    for ( int i = 0 ;
    i < n ;
    i ++ ) {
      graph . put ( i , new ArrayList < Integer > ( ) ) ;
      graph . get ( i ) . add ( i ) ;
    }
    traversalWithBfs ( n , graph , visited ) ;
  }
  public void traversalWithBfs ( final int node , final List < List < Integer >> graph , Set < Integer > visited ) throws Exception {
    final Queue < List < Integer >> queue = new LinkedList < List < Integer >> ( ) ;
    visited . add ( node ) ;
    while ( queue . size ( ) > 0 ) {
      final List < Integer > currentNode = queue . poll ( ) ;
      for ( final List < Integer > neighbor : graph . get ( currentNode ) ) {
        if ( neighbor . contains ( node ) ) {
          continue ;
        }
        queue . add ( neighbor ) ;
        visited . add ( neighbor ) ;
      }
    }
  }
}

Output of TransCoder_my, our retrained model (Java)

static void main ( ) {
  new Solution ( ) . countComponents ( ) ;
  for ( Node node : new LinkedList < > ( ) ) {
    if ( node . visited . add ( node ) ) {
      new Solution ( ) . queue ( ) ;
    }
  }
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
}
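For comparison with the two model outputs above, here is a hand-written Java translation of the Python input. It was written for this write-up, not produced by any TransCoder model, and it reflects our own choices: an adjacency list indexed by node id in place of the Python dict, and `ArrayDeque` as the BFS queue. It keeps the three-method structure of the Python class and the `visited` membership test.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;

// Hand-written reference translation of the Python Solution class
// (not model output).
class Solution {
    // Counts connected components in an undirected graph with nodes 0..n-1.
    public int countComponents(int n, List<List<Integer>> edges) {
        List<List<Integer>> graph = buildGraph(n, edges);
        Set<Integer> visited = new HashSet<>();
        int num = 0;
        for (int node = 0; node < n; node++) {
            if (visited.contains(node)) {
                continue;
            }
            traversalWithBfs(node, graph, visited);
            num++;
        }
        return num;
    }

    // Adjacency list where index = node id (the Python code used a dict keyed 0..n-1).
    private List<List<Integer>> buildGraph(int n, List<List<Integer>> edges) {
        List<List<Integer>> graph = new ArrayList<>(n);
        for (int node = 0; node < n; node++) {
            graph.add(new ArrayList<>());
        }
        for (List<Integer> edge : edges) {
            int preNode = edge.get(0);
            int node = edge.get(1);
            graph.get(preNode).add(node);
            graph.get(node).add(preNode);
        }
        return graph;
    }

    // Breadth-first search that marks every node reachable from `node` as visited.
    private void traversalWithBfs(int node, List<List<Integer>> graph, Set<Integer> visited) {
        Queue<Integer> queue = new ArrayDeque<>();
        queue.add(node);
        visited.add(node);
        while (!queue.isEmpty()) {
            int currentNode = queue.poll();
            for (int neighbor : graph.get(currentNode)) {
                if (visited.contains(neighbor)) {
                    continue;
                }
                queue.add(neighbor);
                visited.add(neighbor);
            }
        }
    }
}
```

With `n = 5` and edges `[[0, 1], [1, 2], [3, 4]]`, `countComponents` returns 2, matching the Python original.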

Environment

2x RTX 3090 (24 GB each)
Training data: Java (50 GB) and Python (50 GB)

Monolingual data -> MLM pretraining
Monolingual functions -> TransCoder training (initialized from the pretrained MLM)

Changed parameters (following the comment in #12):

--n_layers 6 
--emb_dim 1024 
--n_heads 8 

Results

Model                       Java -> Python         Python -> Java
                            k=1        k=10        k=1        k=10
TransCoder_model_1          46.87      48.81       33.89      35.55
TransCoder_model_2          46.87      47.73       32.64      35.97
TransCoder from DOBF        49.24      52.7        39.5       45.32
TransCoder_my (epoch 430)   73.267327  85.148515   58.585859  72.727273

(k = beam size)
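The TransCoder_my scores are reported with six decimals, unlike the other rows. A quick arithmetic check (our own inference, not something stated in the issue) is to find the smallest number of examples `n` for which some pass count `k` gives `100 * k / n` equal to each score after rounding to six decimals. The sketch below does that by brute force; the class and method names are hypothetical. It reports 74/101 and 86/101 for Java -> Python, and 58/99 and 8/11 (equivalently 72/99) for Python -> Java. This may hint that TransCoder_my was scored on a small evaluation set, which could be worth checking against the set used for the other rows.

```java
// Hypothetical helper: recovers the smallest denominator n (up to maxN) such that
// 100 * k / n matches a reported percentage to six decimal places for some integer k.
class ScoreDenominators {
    static int smallestDenominator(double score, int maxN) {
        for (int n = 1; n <= maxN; n++) {
            // Best integer pass count for this n.
            long k = Math.round(score * n / 100.0);
            double reproduced = 100.0 * k / n;
            // Six-decimal rounding means the true value is within 5e-7 of the score.
            if (Math.abs(reproduced - score) < 5e-7) {
                return n;
            }
        }
        return -1; // no denominator up to maxN reproduces the score
    }

    public static void main(String[] args) {
        // TransCoder_my scores from the results table.
        double[] scores = {73.267327, 85.148515, 58.585859, 72.727273};
        for (double s : scores) {
            int n = smallestDenominator(s, 1000);
            long k = Math.round(s * n / 100.0);
            System.out.printf("%.6f = 100 * %d / %d%n", s, k, n);
        }
    }
}
```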