Yuliang-Liu/bezier_curve_text_spotting

train the model

duxiangcheng opened this issue · 6 comments

Thanks for your contribution.
How do I train the model, and how can I get the training data?
Thanks

All code required for training is included in this repo; simply use train_net.py.
All of the training data is also publicly available.
For now, you will need to work out the details yourself; we will release instructions in the future.

@Yuliang-Liu Thanks for sharing the code! I just wanted to check whether word_bezier.yaml needs any parameter changes for training. If I want to fine-tune from your model, what parameter changes would you recommend?
thanks!

An example:

# Example config for the one-stage "align" text spotting model.
OUTPUT_DIR: output/align/07x32
MODEL:
  META_ARCHITECTURE: "OneStage"
  ONE_STAGE_HEAD: "align"
  # Placeholder: replace with the path to your initial weights.
  WEIGHT: "YOUR_MODEL"
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-50"
  NECK:
    CONV_BODY: "fpn-align"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RETINANET:
    USE_C5: False # FCOS uses P5 instead of C5
  ALIGN:
    POOLER_RESOLUTION: (7, 32)
    POOLER_CANONICAL_SCALE: 160
    POOLER_SCALES: (0.25, 0.125, 0.0625)
    # Per the maintainer's reply in this thread, use "attention" when
    # fine-tuning from the provided model.
    PREDICTOR: "ctc" # "ctc" or "attention"
  FCOS:
    CENTER_SAMPLE: True
    POS_RADIUS: 1.5
    LOC_LOSS_TYPE: "giou"
DATASETS:
  # Placeholders: replace with your training / test dataset names.
  TRAIN: ("YOUR_TRAINSET",)
  TEST: ("YOUR_TESTSET",)
  TEXT:
    NUM_CHARS: 25
    VOC_SIZE: 97
INPUT:
  MIN_SIZE_RANGE_TRAIN: (640, 800)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
  FLIP_PROB_TRAIN: 0.0
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (100000, 180000)
  MAX_ITER: 250000
  IMS_PER_BATCH: 2
  WARMUP_METHOD: "constant"
  CHECKPOINT_PERIOD: 2500
TEST:
  IMS_PER_BATCH: 1

@eyebies Simply change `PREDICTOR` from "ctc" to "attention" if you would like to fine-tune from the provided model.

@Yuliang-Liu Further information on how to train from scratch is needed.

@deepseek You can use the following script to generate Bezier control points for rotated boxes. I added code that finds the top and bottom edges so that eight points are generated for each rotated box:

# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy import interpolate
from scipy.special import comb as n_over_k
import glob, os
import cv2

from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

import matplotlib.pyplot as plt
import math
import numpy as np
import random
# from scipy.optimize import leastsq
import torch
from torch import nn
from torch.nn import functional as F

from sklearn.model_selection import train_test_split 
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score

from shapely.geometry import *
from PIL import Image
import time
from bresenham import bresenham
import re 
from tqdm import tqdm

class Bezier(nn.Module):
    """Cubic Bezier curve with fixed end points and learnable inner control points.

    Calling the module returns, summed over every interior row of ``ps``,
    the distance from that point to the closest of 81 curve samples taken
    at ``t = linspace(0, 1, 81)``.

    Args:
        ps: numpy array of (x, y) rows. The first and last rows fix the
            curve's end points; the rows in between are the points fitted.
        ctps: initial inner control points ``[x1, y1, x2, y2]``.
    """

    def __init__(self, ps, ctps):
        super(Bezier, self).__init__()

        def _param(value):
            return nn.Parameter(torch.as_tensor(value, dtype=torch.float64))

        # Registration order (x1, x2, y1, y2) is kept as-is.
        self.x1 = _param(ctps[0])
        self.x2 = _param(ctps[2])
        self.y1 = _param(ctps[1])
        self.y2 = _param(ctps[3])
        # End points are plain (non-learnable) values.
        self.x0, self.y0 = ps[0, 0], ps[0, 1]
        self.x3, self.y3 = ps[-1, 0], ps[-1, 1]
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    def forward(self):
        """Return the summed nearest-sample distance of the interior points."""
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        t = self.t
        s = 1 - t

        def casteljau(p0, p1, p2, p3):
            # De Casteljau evaluation of one coordinate of the cubic curve.
            a = s * p0 + t * p1
            b = s * p1 + t * p2
            c = s * p2 + t * p3
            return s * (s * a + t * b) + t * (s * b + t * c)

        curve = torch.stack((casteljau(x0, x1, x2, x3),
                             casteljau(y0, y1, y2, y3)), dim=1)
        # gaps[i, j] = curve sample j minus interior point i
        gaps = curve.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        nearest = (gaps ** 2).sum(dim=2).sqrt().min(dim=1)[0]
        return nearest.sum()

    def control_points(self):
        """Return ``(x0, x1, x2, x3, y0, y1, y2, y3)``; inner values are Parameters."""
        return (self.x0, self.x1, self.x2, self.x3,
                self.y0, self.y1, self.y2, self.y3)

    def control_points_f(self):
        """Like ``control_points`` but with the learnable values as Python floats."""
        x1, x2, y1, y2 = (p.item() for p in (self.x1, self.x2, self.y1, self.y2))
        return self.x0, x1, x2, self.x3, self.y0, y1, y2, self.y3


def train(x, y, ctps, lr):
    """Refine the inner control points of a cubic Bezier fitted to (x, y).

    Runs up to 1000 SGD steps on ``Bezier``'s loss. The learning rate is
    halved at steps 400 and 800. With ``lr == 0.0`` no optimisation runs
    and the initial control points are returned unchanged.

    Returns:
        ``(x0, x1, x2, x3, y0, y1, y2, y3)``. If the loss ever becomes NaN,
        the control points from before optimisation are returned.
    """
    points = np.vstack((np.array(x), np.array(y))).transpose()
    model = Bezier(points, ctps)
    # Built before the lr check so an invalid lr still raises, as before.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    initial = model.control_points_f()
    if lr == 0.0:
        return model.control_points_f()

    halve_at = (400, 800)
    for step in range(1000):
        loss = model()
        if torch.isnan(loss):
            return initial
        if step in halve_at:
            optimizer.param_groups[0]['lr'] *= 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model.control_points_f()

def draw(ps, control_points, t):
    """Show points and a cubic Bezier curve in an interactive matplotlib window.

    Args:
        ps: array whose columns 0 and 1 hold the x and y coordinates of the
            points, drawn as magenta dots.
        control_points: ``(x0, x1, x2, x3, y0, y1, y2, y3)`` of the curve.
        t: parameter values at which the curve is evaluated; the curve is
            drawn as a green line.

    Blocks until Enter is pressed, then closes the figure (also on error).
    """
    x = ps[:, 0]
    y = ps[:, 1]
    x0, x1, x2, x3, y0, y1, y2, y3 = control_points
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color='m', linestyle='', marker='.')
        bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
        bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
        ax.plot(bezier_x, bezier_y, 'g-')
        plt.draw()
        plt.pause(1)
        # raw_input() only exists in Python 2; this file targets Python 3
        # (it uses math.inf), where input() is the equivalent.
        input("<Hit Enter To Close>")
    finally:
        plt.close(fig)

def Mtk(n, t, k):
    """Return the Bernstein basis value C(n, k) * t**k * (1 - t)**(n - k)."""
    return t**k * (1-t)**(n-k) * n_over_k(n, k)


def BezierCoeff(ts):
    """Return the cubic Bernstein matrix: one row ``[b0(t), b1(t), b2(t), b3(t)]`` per t."""
    return [[Mtk(3, t, k) for k in range(4)] for t in ts]


def bezier_fit(x, y):
    """Least-squares fit of a cubic Bezier curve to the points (x, y).

    Each point gets the curve parameter t = (cumulative chord length) /
    (total chord length). The four control points are then solved by the
    pseudo-inverse of the cubic Bernstein matrix. End points are fitted
    too, but only the two inner control points are returned.

    Args:
        x, y: coordinate sequences of equal length (lists or numpy arrays).

    Returns:
        ``[x1, y1, x2, y2]``: the inner control points.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    dy = y[1:] - y[:-1]
    dx = x[1:] - x[:-1]
    dt = (dx ** 2 + dy ** 2)**0.5
    total = dt.sum()
    if total > 0:
        t = np.hstack(([0], dt / total)).cumsum()
    else:
        # All points coincide: chord-length parameters would be 0/0 (NaN),
        # which makes pinv fail. Spread t uniformly instead.
        t = np.linspace(0.0, 1.0, len(x))

    data = np.column_stack((x, y))
    Pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (n, 4) -> (4, n)
    control_points = Pseudoinverse.dot(data)        # (4, n) @ (n, 2) -> (4, 2)
    medi_ctp = control_points[1:-1, :].flatten().tolist()
    return medi_ctp
    
def bezier_fitv2(x, y):
    """Return inner control points that make the cubic Bezier a straight segment.

    The control points sit 1/3 and 2/3 of the way from the first point
    ``(x[0], y[0])`` to the last point ``(x[-1], y[-1])``. Intermediate
    points are ignored.

    Returns:
        ``[x1, y1, x2, y2]``.
    """
    x_first, x_last = x[0], x[-1]
    y_first, y_last = y[0], y[-1]
    return [
        (2*x_first + x_last)/3.0,
        (2*y_first + y_last)/3.0,
        (x_first + 2*x_last)/3.0,
        (y_first + 2*y_last)/3.0,
    ]

def is_close_to_line(xs, ys, thres):
        """Score how well the points (xs, ys) follow a straight line y = a*x + b.

        Fits a scikit-learn ``LinearRegression`` line and compares squared
        observed y values with squared predicted y values.

        Args:
            xs, ys: numpy arrays of equal length (``.reshape`` is called on both).
            thres: threshold on the normalised error computed below.

        Returns:
            0.0 if the normalised error exceeds ``thres``, otherwise 2.0.

        NOTE(review): the value named ``rmse`` is not an RMSE. It is the MSE
        between *squared* y values, divided by the square of the largest
        (signed) difference of squared values. When that maximum is 0, the
        division may produce inf/nan or raise, depending on operand types.
        A nan compares False against ``thres``, so 2.0 is returned. Confirm
        the intended metric.
        NOTE(review): 0.0 here means "far from a line", whereas
        ``is_close_to_linev2`` returns 0.0 for "close to a line". This
        function is not called anywhere in the visible source.
        """
        regression_model = LinearRegression()
        # Fit a least-squares line to the points.
        regression_model.fit(xs.reshape(-1,1), ys.reshape(-1,1))
        # Predicted y on the fitted line for each x.
        y_predicted = regression_model.predict(xs.reshape(-1,1))

        # Error between squared observed and squared predicted y values,
        # normalised by the squared maximum difference of squared values.
        rmse = mean_squared_error(ys.reshape(-1,1)**2, y_predicted**2)
        rmse = rmse/(ys.reshape(-1,1)**2- y_predicted**2).max()**2

        if rmse > thres:
                return 0.0
        else:
                return 2.0

def is_close_to_linev2(xs, ys, size, thres = 0.05):
    """Decide whether the polyline (xs, ys) is effectively a straight line.

    Each segment's slope is compared with the slope of the chord from the
    first point to the last. The summed absolute deviation is multiplied by
    the chord length and divided by ``int(sqrt(size))``.

    A vertical segment gets slope +/-inf, and a repeated point gets nan.
    Either generally makes the score inf/nan, and 3.0 is returned.

    Args:
        xs, ys: x and y coordinates (any sequence or numpy array, same length).
        size: reference area in pixels; ``int(size ** 0.5)`` normalises the score.
        thres: score below which the polyline counts as a straight line.

    Returns:
        0.0 when the score is below ``thres``, otherwise 3.0.
    """
    # np.asarray lets plain lists work; the old list-minus-float expression
    # raised TypeError for non-numpy inputs.
    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    nor_pixel = int(size**0.5)
    dx = xs[1:] - xs[:-1]
    dy = ys[1:] - ys[:-1]
    # Zero-length / vertical segments and chords intentionally yield inf/nan;
    # silence numpy's divide/invalid warnings for those cases.
    with np.errstate(divide='ignore', invalid='ignore'):
        slopes = np.where(dx != 0.0, dy / dx, math.inf * np.sign(dy))
        chord_dx = xs[-1] - xs[0]
        chord_dy = ys[-1] - ys[0]
        st_slope = chord_dy / chord_dx
        max_dis = (chord_dy**2 + chord_dx**2)**(0.5)
        score = np.abs(slopes - st_slope).sum() * max_dis / nor_pixel

    return 0.0 if score < thres else 3.0


def find_long_edges(points, bottoms):
    """Return the two chains of edges that connect the two bottom edges.

    Args:
        points: the polygon's vertices; only ``len(points)`` is used.
        bottoms: two ``(start, end)`` vertex-index pairs (from ``find_bottom``).

    Returns:
        ``(long_edge_1, long_edge_2)``, each a list of ``(start, end)`` index
        pairs. The first runs from the end of ``bottoms[0]`` up to the end
        of ``bottoms[1]``; the second runs from the end of ``bottoms[1]`` up
        to the end of ``bottoms[0]``. Indices wrap around modulo the vertex
        count.
    """
    count = len(points)
    (_, first_end), (_, second_end) = bottoms[0], bottoms[1]

    def walk(after, stop):
        # Collect consecutive edges, starting with the edge that leaves
        # vertex ``after``, until vertex ``stop`` is reached.
        edges = []
        idx = (after + 1) % count
        while idx != stop:
            edges.append(((idx - 1) % count, idx))
            idx = (idx + 1) % count
        return edges

    return walk(first_end, second_end), walk(second_end, first_end)
def norm2(x, axis=None):
    """Return the Euclidean (L2) norm of ``x``.

    Args:
        x: array-like of numbers.
        axis: axis (or tuple of axes) to reduce over; ``None`` reduces over
            all elements. ``axis=0`` is honoured. Previously ``if axis:``
            silently treated it as ``None``.
    """
    x = np.asarray(x)
    return np.sqrt(np.sum(x ** 2, axis=axis))

def cos(p1, p2):
    """Return the cosine of the angle between numpy vectors ``p1`` and ``p2``.

    There is no guard for zero-length vectors; with numpy inputs the
    division then produces nan.
    """
    dot = (p1 * p2).sum()
    length1 = np.sqrt(np.sum(p1 ** 2))
    length2 = np.sqrt(np.sum(p2 ** 2))
    return dot / (length1 * length2)

def find_bottom(pts):
    """Find the two short "bottom" (end) edges of a text polygon.

    Args:
        pts: numpy array of polygon vertices in order, one (x, y) row per
            vertex.

    Returns:
        A list of two ``(start_idx, end_idx)`` vertex-index pairs, each
        naming the edge ``pts[start_idx] -> pts[end_idx]``.

    With more than 4 vertices, edge (i, i+1) is a candidate when the edge
    entering vertex i and the edge leaving vertex i+1 point in roughly
    opposite directions (cosine < -0.7). If there are not exactly two
    candidates, or the two candidates are joined end-to-start, the pair of
    edges whose midpoints are farthest apart is used instead. With 4
    vertices, the pair of opposite edges with the smaller summed length is
    returned.

    NOTE(review): the ``else`` branch indexes ``pts[0]``..``pts[3]``, so it
    assumes exactly 4 vertices; fewer raises IndexError.
    """

    if len(pts) > 4:
        # Wrap the first 3 vertices around so e[i + 2] is valid for every i.
        e = np.concatenate([pts, pts[:3]])
        candidate = []
        for i in range(1, len(pts) + 1):
            v_prev = e[i] - e[i - 1]      # edge entering vertex i
            v_next = e[i + 2] - e[i + 1]  # edge leaving vertex i + 1
            if cos(v_prev, v_next) < -0.7:
                # (start, end, length) of candidate edge i -> i + 1
                candidate.append((i % len(pts), (i + 1) % len(pts), norm2(e[i] - e[i + 1])))

        if len(candidate) != 2 or candidate[0][0] == candidate[1][1] or candidate[0][1] == candidate[1][0]:
            # Not exactly two candidates, or the two are joined: fall back to
            # the two edges whose midpoints are farthest apart.
            mid_list = []
            for i in range(len(pts)):
                mid_point = (e[i] + e[(i + 1) % len(pts)]) / 2
                mid_list.append((i, (i + 1) % len(pts), mid_point))

            dist_list = []
            for i in range(len(pts)):
                for j in range(len(pts)):
                    s1, e1, mid1 = mid_list[i]
                    s2, e2, mid2 = mid_list[j]
                    dist = norm2(mid1 - mid2)
                    dist_list.append((s1, e1, s2, e2, dist))
            # dist_list holds both (i, j) and (j, i); the two largest entries
            # are normally that same pair in both orders, so their first edges
            # are the two distinct edges.
            # NOTE(review): ties between different pairs could yield a
            # mismatched result -- confirm this is acceptable.
            bottom_idx = np.argsort([dist for s1, e1, s2, e2, dist in dist_list])[-2:]
            bottoms = [dist_list[bottom_idx[0]][:2], dist_list[bottom_idx[1]][:2]]
        else:
            bottoms = [candidate[0][:2], candidate[1][:2]]

    else:
        # Quadrilateral: pick the opposite-edge pair with the smaller total length.
        d1 = norm2(pts[1] - pts[0]) + norm2(pts[2] - pts[3])
        d2 = norm2(pts[2] - pts[1]) + norm2(pts[0] - pts[3])
        bottoms = [(0, 1), (2, 3)] if d1 < d2 else [(1, 2), (3, 0)]
    assert len(bottoms) == 2, 'fewer than 2 bottoms'
    return bottoms

def cal_control_pts(coords):
    """Turn a text box into 8 points (two 4-point edges) for Bezier fitting.

    The box's long edges are located with ``find_bottom`` and
    ``find_long_edges``. Only the first segment of each long edge is used,
    so this is meant for 4-vertex (rotated) boxes. The vertices are
    reordered to ``p0, p1, p2, p3``.

    Returns:
        A list of 8 ``[x, y]`` pairs:
        - p0, the points 1/3 and 2/3 of the way to p1, then p1;
        - p2, the points 2/3 and 1/3 of the way from p3 to p2, then p3.
    """
    poly = np.array(coords)
    first_long, second_long = find_long_edges(poly, find_bottom(poly))
    (a0, a1), (b0, b1) = first_long[0], second_long[0]
    (x0, y0), (x1, y1), (x2, y2), (x3, y3) = poly[[a1, a0, b1, b0]]

    def lerp(start, end, frac):
        # Point ``frac`` of the way from ``start`` to ``end``.
        return frac * (end - start) + start

    return [
        [x0, y0],
        [lerp(x0, x1, 1./3), lerp(y0, y1, 1./3)],
        [lerp(x0, x1, 2./3), lerp(y0, y1, 2./3)],
        [x1, y1],
        [x2, y2],
        [lerp(x3, x2, 2./3), lerp(y3, y2, 2./3)],
        [lerp(x3, x2, 1./3), lerp(y3, y2, 1./3)],
        [x3, y3],
    ]

import sys


def _read_instances(label_path):
    """Parse one annotation file into Bezier-ready point arrays.

    Each useful line is ``x1,y1,...,xn,yn,text``. Its vertices are turned
    into 8 points by ``cal_control_pts``. Lines are skipped if:
    - the text is ``'#'``;
    - they have no coordinates (e.g. blank lines);
    - their coordinates cannot be parsed.

    Returns:
        ``(data, cts)``: ``data[i]`` is a flat numpy array of the point
        coordinates of instance ``i`` and ``cts[i]`` is its transcription.
    """
    data, cts = [], []
    with open(label_path, 'r', encoding='utf-8') as fin:
        for raw_line in fin:
            fields = raw_line.strip().split(',')
            ct = fields[-1]
            # NOTE(review): only a bare '#' is treated as "don't care"; some
            # datasets mark ignored text as '###' -- confirm for your data.
            if ct == '#':
                continue
            fields = [item.replace('\ufeff', '') for item in fields]
            coord_fields = fields[:-1]
            if not coord_fields:
                # Nothing to fit; find_bottom would raise on an empty polygon.
                continue
            try:
                coords = [(float(coord_fields[ix]), float(coord_fields[ix + 1]))
                          for ix in range(0, len(coord_fields), 2)]
            except (ValueError, IndexError):
                # Non-numeric field (e.g. a comma inside the text) or an odd
                # number of coordinates.
                continue
            data.append(np.array(cal_control_pts(coords)).reshape(-1))
            cts.append(ct)
    return data, cts


def _fit_side(curve):
    """Fit one cubic Bezier to ``curve`` (an (m, 2) array of points).

    Returns ``(x0, x1, x2, x3, y0, y1, y2, y3)``. ``train`` is called with
    ``lr=0.0``, so no gradient refinement runs. The inner control points are
    the least-squares estimate from ``bezier_fit``.
    """
    x_data = curve[:, 0]
    y_data = curve[:, 1]
    return train(x_data, y_data, bezier_fit(x_data, y_data), 0.0)


def _control_array(ctrl):
    """Turn ``(x0, x1, x2, x3, y0, y1, y2, y3)`` into a (4, 2) array of (x, y) rows."""
    return np.column_stack((ctrl[:4], ctrl[4:]))


def _plot_instance(curve_top, curve_bottom, control_points, control_points_b):
    """Draw both fitted curves, their control polygons and the two side edges."""
    t_plot = np.linspace(0, 1, 81)
    coeffs = np.array(BezierCoeff(t_plot))
    bezier_top = coeffs.dot(control_points)
    bezier_bottom = coeffs.dot(control_points_b)
    plt.plot(bezier_top[:, 0], bezier_top[:, 1], 'g-', label='fit', linewidth=1.0)
    plt.plot(bezier_bottom[:, 0], bezier_bottom[:, 1], 'g-', label='fit', linewidth=1.0)
    plt.plot(control_points[:, 0], control_points[:, 1], 'r.:', fillstyle='none', linewidth=1.0)
    plt.plot(control_points_b[:, 0], control_points_b[:, 1], 'r.:', fillstyle='none', linewidth=1.0)
    # Side edges: top start <-> bottom end, and top end <-> bottom start.
    plt.plot([curve_top[0, 0], curve_bottom[-1, 0]],
             [curve_top[0, 1], curve_bottom[-1, 1]], 'g-', linewidth=1.0)
    plt.plot([curve_top[-1, 0], curve_bottom[0, 0]],
             [curve_top[-1, 1], curve_bottom[0, 1]], 'g-', linewidth=1.0)


def _format_line(top, bottom, text):
    """Format one output line: 16 comma-separated coordinates, then ``text``.

    Every coordinate is rounded to 2 decimals (including x0/y0, which were
    previously written unrounded).
    """
    x0, x1, x2, x3, y0, y1, y2, y3 = top
    bx0, bx1, bx2, bx3, by0, by1, by2, by3 = bottom
    values = (x0, y0, x1, y1, x2, y2, x3, y3,
              bx0, by0, bx1, by1, bx2, by2, bx3, by3)
    return ','.join('{}'.format(round(v, 2)) for v in values) + ',{}\n'.format(text)


def main(argv=None):
    """Convert ``DATA_DIR/*.txt`` box annotations into Bezier control-point files.

    Usage: ``python <script> DATA_DIR OUT_DIR [VIS_DIR]``.

    For every label file, a file with the same basename is written to
    OUT_DIR, with one line per kept text instance (see ``_format_line``).
    If VIS_DIR is given, the fitted curves are drawn over the matching
    ``.jpg`` image next to the label file and saved into VIS_DIR.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        prog = os.path.basename(argv[0]) if argv else 'script'
        sys.exit('usage: {} DATA_DIR OUT_DIR [VIS_DIR]'.format(prog))
    data_dir, out_dir = argv[1], argv[2]
    vis_dir = argv[3] if len(argv) > 3 else None
    os.makedirs(out_dir, exist_ok=True)
    if vis_dir:
        os.makedirs(vis_dir, exist_ok=True)

    labels = sorted(glob.glob(os.path.join(data_dir, '*.txt')))
    for label in tqdm(labels):
        data, cts = _read_instances(label)
        img_path = os.path.splitext(label)[0] + '.jpg'
        if vis_dir:
            plt.imshow(plt.imread(img_path))

        out_path = os.path.join(out_dir, os.path.basename(label))
        with open(out_path, 'w', encoding='utf-8') as outgt:
            for flat, ct in zip(data, cts):
                if len(flat) % 4 != 0:
                    raise ValueError('{}: expected a multiple of 4 coordinates, got {}'
                                     .format(label, len(flat)))
                half = len(flat) // 2
                curve_top = flat[:half].reshape(-1, 2)
                curve_bottom = flat[half:].reshape(-1, 2)
                top = _fit_side(curve_top)
                bottom = _fit_side(curve_bottom)
                outgt.write(_format_line(top, bottom, ct))
                if vis_dir:
                    _plot_instance(curve_top, curve_bottom,
                                   _control_array(top), _control_array(bottom))

        if vis_dir:
            plt.axis('off')
            plt.savefig(os.path.join(vis_dir, os.path.basename(img_path)),
                        bbox_inches='tight', dpi=400)
            # Clear so plots do not pile up across images.
            plt.clf()


if __name__ == '__main__':
    main()

After you get the Bezier points, you can combine them with the original text annotations to generate COCO-format annotations. You need to add extra fields to each annotation:

{
                'area': h*w,
                'bbox': box,
                'category_id': cat_id,
                'id': ann_id,
                'image_id': image_id,
                'iscrowd': 0,
                'segmentation': [poly],
                'text': [text],
                'bezier_pts': [bezier_pts], # bezier points you generated for each text instance
                'rec': [rec] # text label for recognition head
 }

Then configure the data paths and run `python tools/train.py --config-file *.yaml`; the model should train well.

To generate annotations for curved text, use the script in the README; everything I described above applies only to rotated boxes.

@saicoco Thank you.

We will release our full code in the Adet next week, including the models of CTW1500 and Total-text, the training data we used, evaluation scripts, results of detection, etc. This repo will not be maintained anymore.

Thanks for your attention.