train the model
duxiangcheng opened this issue · 6 comments
Thanks for your contribution.
How do I train the model, and how do I get the training data?
Thanks
All code required for training is included in this repo. Simply use train_net.py.
All the training data are also publicly available.
You will need to figure out the "how" yourself for the moment. We will release instructions in the future.
@Yuliang-Liu Thanks for sharing the code! Just wanted to check whether word_bezier.yaml needs any parameter changes for training. If I want to fine-tune from your model, what parameter changes would you recommend?
thanks!
An Example:
# NOTE(review): the indentation of this example was lost when the comment was
# pasted. The nesting below is reconstructed from the section headers (keys
# with no value) and must be confirmed against the repository's own configs --
# in particular whether TEXT belongs under DATASETS.
OUTPUT_DIR: output/align/07x32
MODEL:
  META_ARCHITECTURE: "OneStage"
  ONE_STAGE_HEAD: "align"
  WEIGHT: "YOUR_MODEL"
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-50"
  NECK:
    CONV_BODY: "fpn-align"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RETINANET:
    USE_C5: False # FCOS uses P5 instead of C5
  ALIGN:
    POOLER_RESOLUTION: (7, 32)
    POOLER_CANONICAL_SCALE: 160
    POOLER_SCALES: (0.25, 0.125, 0.0625)
    PREDICTOR: "ctc" # "ctc" or "attention"
  FCOS:
    CENTER_SAMPLE: True
    POS_RADIUS: 1.5
    LOC_LOSS_TYPE: "giou"
DATASETS:
  TRAIN: ("YOUR_TRAINSET",)
  TEST: ("YOUR_TESTSET",)
  TEXT:
    NUM_CHARS: 25
    VOC_SIZE: 97
INPUT:
  MIN_SIZE_RANGE_TRAIN: (640, 800)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
  FLIP_PROB_TRAIN: 0.0
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (100000, 180000)
  MAX_ITER: 250000
  IMS_PER_BATCH: 2
  WARMUP_METHOD: "constant"
  CHECKPOINT_PERIOD: 2500
TEST:
  IMS_PER_BATCH: 1
@eyebies Simply change "ctc" to "attention" if you would like to fine-tune from the provided model.
@Yuliang-Liu further information on how to train from scratch is required.
@deepseek You can use the following script to generate Bezier points for rotated boxes. I added steps that find the top edge and the bottom edge, so that eight points are generated for each rotated box:
# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy import interpolate
from scipy.special import comb as n_over_k
import glob, os
import cv2
from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean
import matplotlib.pyplot as plt
import math
import numpy as np
import random
# from scipy.optimize import leastsq
import torch
from torch import nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score
from shapely.geometry import *
from PIL import Image
import time
from bresenham import bresenham
import re
from tqdm import tqdm
class Bezier(nn.Module):
    """Cubic Bezier curve with fixed end points and trainable inner control points.

    Args:
        ps: 2-D array of sample points, one ``(x, y)`` row per point. The
            first and last rows are used as the fixed curve end points; the
            rows in between are the targets the curve is fitted to.
        ctps: Initial inner control points, ordered ``[x1, y1, x2, y2]``.
    """

    def __init__(self, ps, ctps):
        super().__init__()
        init_x1, init_y1, init_x2, init_y2 = ctps[0], ctps[1], ctps[2], ctps[3]
        # Only the two inner control points are registered as parameters.
        self.x1 = nn.Parameter(torch.as_tensor(init_x1, dtype=torch.float64))
        self.x2 = nn.Parameter(torch.as_tensor(init_x2, dtype=torch.float64))
        self.y1 = nn.Parameter(torch.as_tensor(init_y1, dtype=torch.float64))
        self.y2 = nn.Parameter(torch.as_tensor(init_y2, dtype=torch.float64))
        # The end points are taken straight from the data and never updated.
        self.x0, self.y0 = ps[0, 0], ps[0, 1]
        self.x3, self.y3 = ps[-1, 0], ps[-1, 1]
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        # 81 evenly spaced curve parameters in [0, 1].
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    @staticmethod
    def _de_casteljau(t, p0, p1, p2, p3):
        """Evaluate one coordinate of the cubic curve by repeated interpolation."""
        s = 1 - t
        q0 = s * p0 + t * p1
        q1 = s * p1 + t * p2
        q2 = s * p2 + t * p3
        return s * (s * q0 + t * q1) + t * (s * q1 + t * q2)

    def forward(self):
        """Return the fitting loss.

        The loss is the sum, over every inner sample point, of the Euclidean
        distance to the closest of the sampled curve points.
        """
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        curve_x = self._de_casteljau(self.t, x0, x1, x2, x3)
        curve_y = self._de_casteljau(self.t, y0, y1, y2, y3)
        curve = torch.stack((curve_x, curve_y), dim=1)
        # Pairwise offsets: (num_inner_points, num_curve_samples, 2).
        offsets = curve.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        distances = (offsets ** 2).sum(dim=2).sqrt()
        nearest, _ = distances.min(dim=1)
        return nearest.sum()

    def control_points(self):
        """Return ``(x0, x1, x2, x3, y0, y1, y2, y3)`` with the inner points as parameters."""
        return (self.x0, self.x1, self.x2, self.x3,
                self.y0, self.y1, self.y2, self.y3)

    def control_points_f(self):
        """Return the same tuple as ``control_points`` with inner points as Python floats."""
        return (self.x0, self.x1.item(), self.x2.item(), self.x3,
                self.y0, self.y1.item(), self.y2.item(), self.y3)
def train(x, y, ctps, lr):
    """Refine the inner Bezier control points with SGD.

    Args:
        x: x coordinates of the sample points.
        y: y coordinates of the sample points.
        ctps: Initial inner control points ``[x1, y1, x2, y2]``.
        lr: SGD learning rate. ``0.0`` skips optimisation entirely.

    Returns:
        ``(x0, x1, x2, x3, y0, y1, y2, y3)`` as produced by
        ``Bezier.control_points_f``. If the loss becomes NaN, the control
        points from before optimisation are returned instead.
    """
    samples = np.vstack((np.array(x), np.array(y))).transpose()
    model = Bezier(samples, ctps)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Snapshot taken before any update so a diverging run can fall back to it.
    start_pts = model.control_points_f()

    if lr == 0.0:
        return model.control_points_f()

    for step in range(1000):
        loss = model()
        if torch.isnan(loss):
            return start_pts
        # Halve the learning rate at iterations 400 and 800.
        if step in (400, 800):
            optimizer.param_groups[0]['lr'] *= 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model.control_points_f()
def draw(ps, control_points, t):
    """Plot the sample points and the fitted cubic Bezier curve, then wait for Enter.

    Args:
        ps: 2-D array of sample points, one ``(x, y)`` row per point.
        control_points: ``(x0, x1, x2, x3, y0, y1, y2, y3)``.
        t: Curve parameter values at which the curve is evaluated.
    """
    x0, x1, x2, x3, y0, y1, y2, y3 = control_points
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(ps[:, 0], ps[:, 1], color='m', linestyle='', marker='.')
    bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
    bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
    plt.plot(bezier_x, bezier_y, 'g-')
    plt.draw()
    plt.pause(1)  # give the GUI event loop time to render the figure
    # `raw_input` only exists in Python 2; the rest of this script is
    # Python 3, where the equivalent builtin is `input`.
    input("<Hit Enter To Close>")
    plt.close(fig)
def Mtk(n, t, k):
    """Return the Bernstein basis term ``C(n, k) * t**k * (1 - t)**(n - k)``."""
    return t**k * (1-t)**(n-k) * n_over_k(n, k)


def BezierCoeff(ts):
    """Return one row of the four cubic Bernstein coefficients per value in ``ts``."""
    rows = []
    for t in ts:
        rows.append([Mtk(3, t, k) for k in range(4)])
    return rows
def bezier_fit(x, y):
    """Least-squares fit of a cubic Bezier curve; return its inner control points.

    Each sample is assigned a curve parameter equal to its normalised
    cumulative chord length along the polyline, and the four control points
    are solved for with a pseudo-inverse of the Bernstein coefficient matrix.

    Args:
        x: x coordinates of the samples (array-like).
        y: y coordinates of the samples (array-like).

    Returns:
        The two inner control points as a flat list ``[x1, y1, x2, y2]``.
    """
    # Accept plain sequences as well as arrays; slicing arithmetic below
    # needs ndarray semantics.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    dx = x[1:] - x[:-1]
    dy = y[1:] - y[:-1]
    seg_len = (dx ** 2 + dy ** 2) ** 0.5
    total_len = seg_len.sum()

    if total_len > 0:
        t = np.hstack(([0], seg_len / total_len)).cumsum()
    else:
        # All samples coincide: dividing by the zero total length would give
        # NaN parameters and poison the pseudo-inverse. Use evenly spaced
        # parameters instead.
        t = np.linspace(0, 1, len(x))

    data = np.column_stack((x, y))
    pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (n, 4) -> (4, n)
    control_points = pseudoinverse.dot(data)  # (4, n) * (n, 2) -> (4, 2)
    return control_points[1:-1, :].flatten().tolist()
def bezier_fitv2(x, y):
    """Return inner control points placed on the straight chord.

    The two points sit one third and two thirds of the way from the first
    sample to the last one; intermediate samples are ignored.

    Returns:
        ``[x1, y1, x2, y2]``.
    """
    x_first, x_last = x[0], x[-1]
    y_first, y_last = y[0], y[-1]
    return [
        (2*x_first + x_last)/3.0,
        (2*y_first + y_last)/3.0,
        (x_first + 2*x_last)/3.0,
        (y_first + 2*y_last)/3.0,
    ]
def is_close_to_line(xs, ys, thres):
    """Score how well a straight line explains the points and pick 0.0 or 2.0.

    A linear regression of ``ys`` on ``xs`` is fitted. The mean squared error
    between the squared targets and the squared predictions is divided by the
    square of the largest (signed) difference between those squared values.

    Returns:
        ``0.0`` when the normalised error exceeds ``thres``, otherwise ``2.0``.
    """
    features = xs.reshape(-1, 1)
    targets = ys.reshape(-1, 1)

    model = LinearRegression()
    model.fit(features, targets)
    predicted = model.predict(features)

    squared_targets = targets**2
    squared_predicted = predicted**2
    error = mean_squared_error(squared_targets, squared_predicted)
    error = error/(squared_targets - squared_predicted).max()**2

    return 0.0 if error > thres else 2.0
def is_close_to_linev2(xs, ys, size, thres = 0.05):
    """Decide from segment slopes whether the points lie close to a line.

    The slope of every consecutive segment is compared with the slope of the
    chord from the first to the last point. The summed absolute differences
    are scaled by the chord length and divided by ``int(size ** 0.5)``.

    Args:
        xs: x coordinates (array-like).
        ys: y coordinates (array-like).
        size: Area used for normalisation; its integer square root is the
            divisor.
        thres: Score below which the points count as a line.

    Returns:
        ``0.0`` when the score is below ``thres``, otherwise ``3.0``. A
        vertical chord or vertical segment makes the score infinite or NaN,
        which also yields ``3.0``.
    """
    # Converting up front makes plain Python lists work: previously the
    # arithmetic only succeeded when the elements were numpy scalars.
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    nor_pixel = int(size**0.5)

    def _slope(dx, dy):
        # A vertical step has a signed infinite slope (NaN when dy is 0 too).
        if dx == 0.0:
            return math.inf * np.sign(dy)
        return dy / dx

    # inf - inf and inf * 0 are expected for vertical input; silence the
    # floating-point warnings and let the NaN fall through to the 3.0 branch.
    with np.errstate(invalid='ignore', divide='ignore'):
        slopes = np.array(
            [_slope(xs[i + 1] - xs[i], ys[i + 1] - ys[i]) for i in range(len(xs) - 1)],
            dtype=float,
        )
        chord_dx = xs[-1] - xs[0]
        chord_dy = ys[-1] - ys[0]
        st_slope = _slope(chord_dx, chord_dy)
        max_dis = (chord_dy**2 + chord_dx**2)**(0.5)
        score = np.abs(slopes - st_slope).sum() * max_dis/nor_pixel

    if score < thres:
        return 0.0
    return 3.0
def find_long_edges(points, bottoms):
    """Collect the edges that run between the two bottom edges.

    Args:
        points: Polygon vertices; only their count is used.
        bottoms: Two ``(start, end)`` vertex-index pairs for the bottom edges.

    Returns:
        ``(long_edge_1, long_edge_2)``: lists of ``(start, end)`` index pairs.
        The first walks from the end of the first bottom edge up to the end
        of the second; the second walks the other way round.
    """
    n_pts = len(points)
    _, first_end = bottoms[0]
    _, second_end = bottoms[1]

    def _walk(from_idx, stop_idx):
        # Step around the polygon, emitting the edge that arrives at each
        # vertex, until the stop vertex is reached (its edge is excluded).
        edges = []
        cur = (from_idx + 1) % n_pts
        while cur != stop_idx:
            edges.append(((cur - 1) % n_pts, cur))
            cur = (cur + 1) % n_pts
        return edges

    long_edge_1 = _walk(first_end, second_end)
    long_edge_2 = _walk(second_end, first_end)
    return long_edge_1, long_edge_2
def norm2(x, axis=None):
    """Return the Euclidean (L2) norm of ``x``.

    Args:
        x: Input array.
        axis: Axis along which to reduce. ``None`` (the default) reduces over
            every element and returns a single value.
    """
    # Compare against None explicitly: a truthiness test would treat the
    # valid value ``axis=0`` as "no axis" and silently return the global norm.
    if axis is not None:
        return np.sqrt(np.sum(x ** 2, axis=axis))
    return np.sqrt(np.sum(x ** 2))
def cos(p1, p2):
    """Return the cosine of the angle between vectors ``p1`` and ``p2``."""
    dot = (p1 * p2).sum()
    len_p1 = np.sqrt(np.sum(p1 ** 2))
    len_p2 = np.sqrt(np.sum(p2 ** 2))
    return dot / (len_p1 * len_p2)
def find_bottom(pts):
    """Pick the two "bottom" edges of a polygon.

    Args:
        pts: Array of polygon vertices, one ``(x, y)`` row per vertex.

    Returns:
        A list of two ``(start, end)`` vertex-index pairs.

        * With at most four vertices, the pair of opposite edges with the
          smaller summed length is returned.
        * With more vertices, an edge is a candidate when the edges before
          and after it point in nearly opposite directions (cosine below
          -0.7). If that does not give exactly two edges that do not share
          a vertex, the edges whose midpoints are farthest apart are used.
    """
    n_pts = len(pts)

    if n_pts <= 4:
        len_01_23 = norm2(pts[1] - pts[0]) + norm2(pts[2] - pts[3])
        len_12_30 = norm2(pts[2] - pts[1]) + norm2(pts[0] - pts[3])
        bottoms = [(0, 1), (2, 3)] if len_01_23 < len_12_30 else [(1, 2), (3, 0)]
    else:
        # Append the first three vertices so neighbours can be indexed
        # without wrapping.
        ring = np.concatenate([pts, pts[:3]])
        candidate = []
        for k in range(1, n_pts + 1):
            before = ring[k] - ring[k - 1]
            after = ring[k + 2] - ring[k + 1]
            if cos(before, after) < -0.7:
                candidate.append((k % n_pts, (k + 1) % n_pts, norm2(ring[k] - ring[k + 1])))

        usable = (
            len(candidate) == 2
            and candidate[0][0] != candidate[1][1]
            and candidate[0][1] != candidate[1][0]
        )
        if usable:
            bottoms = [candidate[0][:2], candidate[1][:2]]
        else:
            # Not exactly two candidates, or the two are joined: fall back
            # to the two edges whose midpoints are farthest apart.
            mids = [
                (k, (k + 1) % n_pts, (ring[k] + ring[(k + 1) % n_pts]) / 2)
                for k in range(n_pts)
            ]
            pair_dists = [
                (s1, e1, s2, e2, norm2(mid1 - mid2))
                for s1, e1, mid1 in mids
                for s2, e2, mid2 in mids
            ]
            farthest = np.argsort([entry[4] for entry in pair_dists])[-2:]
            bottoms = [pair_dists[farthest[0]][:2], pair_dists[farthest[1]][:2]]

    assert len(bottoms) == 2, 'fewer than 2 bottoms'
    return bottoms
def cal_control_pts(coords):
    """Turn a polygon into eight points: four along each long side.

    The bottom edges are located with ``find_bottom``; the first edge of each
    long side returned by ``find_long_edges`` is then split into thirds.

    Args:
        coords: Sequence of ``(x, y)`` polygon vertices.

    Returns:
        Eight ``[x, y]`` pairs: the first long edge from one end to the other
        with two evenly spaced points in between, followed by the second long
        edge laid out the same way.
    """
    polygon = np.array(coords)
    short_edges = find_bottom(polygon)
    side_a, side_b = find_long_edges(polygon, short_edges)
    a_start, a_end = side_a[0]
    b_start, b_end = side_b[0]

    # Each selected edge is taken end-first, start-second.
    (x0, y0), (x1, y1), (x2, y2), (x3, y3) = polygon[[a_end, a_start, b_end, b_start]]

    def _between(origin, target, frac):
        # Point `frac` of the way from `origin` towards `target`.
        return frac * (target - origin) + origin

    one_third = 1./3
    two_thirds = 2./3
    return [
        [x0, y0],
        [_between(x0, x1, one_third), _between(y0, y1, one_third)],
        [_between(x0, x1, two_thirds), _between(y0, y1, two_thirds)],
        [x1, y1],
        [x2, y2],
        [_between(x3, x2, two_thirds), _between(y3, y2, two_thirds)],
        [_between(x3, x2, one_third), _between(y3, y2, one_third)],
        [x3, y3],
    ]
import sys

# Usage: <script> <label_dir> <out_dir>
# Reads every ``*.txt`` label file in ``label_dir`` and writes a file of the
# same name into ``out_dir`` with one line per text instance:
# 16 Bezier control-point values followed by the transcription.
if len(sys.argv) < 3:
    sys.exit('usage: {} <label_dir> <out_dir>'.format(sys.argv[0]))
data_dir = sys.argv[1]
out_dir = sys.argv[2]
os.makedirs(out_dir, exist_ok=True)

labels = glob.glob('{}/*.txt'.format(data_dir))
labels.sort()

# Set to True to also draw the fitted curves over each image and save the
# result under ``vis_results/``. Off by default: drawing without ever saving
# or clearing the figure only accumulated plot artists in memory.
# NOTE(review): the save path is restored from previously commented-out code
# and has not been exercised -- verify before relying on it.
VISUALIZE = False

# x0 and y0 are written unrounded, every other coordinate is rounded to two
# decimals, and the transcription comes last (17 fields).
OUT_TEMPLATE = ','.join(['{}'] * 17) + '\n'


def _parse_coords(fields):
    """Convert a flat list of coordinate strings into ``(x, y)`` float pairs."""
    return [(float(fields[ix]), float(fields[ix + 1])) for ix in range(0, len(fields), 2)]


def _fit_edge(edge_pts):
    """Return ``(x0, x1, x2, x3, y0, y1, y2, y3)`` for one edge of sample points."""
    xs = edge_pts[:, 0]
    ys = edge_pts[:, 1]
    init_ctps = bezier_fit(xs, ys)
    # Learning rate 0.0: `train` performs no optimisation and just combines
    # the edge end points with the least-squares inner control points.
    return train(xs, ys, init_ctps, 0.0)


def _plot_instance(ctps_top, ctps_bottom):
    """Draw both fitted curves, their control polygons and the two closing sides."""
    coeff = np.array(BezierCoeff(np.linspace(0, 1, 81)))
    for ctps in (ctps_top, ctps_bottom):
        curve = coeff.dot(ctps)
        plt.plot(curve[:, 0], curve[:, 1], 'g-', label='fit', linewidth=1.0)
        plt.plot(ctps[:, 0], ctps[:, 1], 'r.:', fillstyle='none', linewidth=1.0)
    # Left side joins the start of the top edge to the end of the bottom
    # edge; right side joins the end of the top edge to the bottom start.
    plt.plot([ctps_top[0, 0], ctps_bottom[-1, 0]],
             [ctps_top[0, 1], ctps_bottom[-1, 1]], 'g-', linewidth=1.0)
    plt.plot([ctps_top[-1, 0], ctps_bottom[0, 0]],
             [ctps_top[-1, 1], ctps_bottom[0, 1]], 'g-', linewidth=1.0)


for label in tqdm(labels):
    data = []
    cts = []
    with open(label, 'r') as fin:
        lines = fin.readlines()

    for line in lines:
        line = line.strip().split(',')
        ct = line[-1]
        # Instances whose last field is exactly '#' are skipped.
        if ct == '#':
            continue
        line = [item.replace('\ufeff', '') for item in line]
        try:
            coords = _parse_coords(line[:-1])
        except (ValueError, IndexError):
            # Non-numeric field or odd number of coordinates: skip the line.
            # NOTE: a transcription that itself contains commas lands here
            # too, because the line is split on every comma.
            continue
        coords = cal_control_pts(coords)
        data.append(np.array(coords).reshape((-1)))
        cts.append(ct)

    with open(os.path.join(out_dir, os.path.basename(label)), 'w') as outgt:
        for ddata, ct in zip(data, cts):
            lh = len(ddata)
            if lh % 4 != 0:
                raise ValueError('expected a multiple of 4 values, got {}'.format(lh))
            half = lh // 2
            quarter = lh // 4
            # First half of the values is the top edge, second half the bottom.
            curve_data_top = ddata[:half].reshape(quarter, 2)
            curve_data_bottom = ddata[half:].reshape(quarter, 2)

            x0, x1, x2, x3, y0, y1, y2, y3 = _fit_edge(curve_data_top)
            x0_b, x1_b, x2_b, x3_b, y0_b, y1_b, y2_b, y3_b = _fit_edge(curve_data_bottom)

            if VISUALIZE:
                _plot_instance(
                    np.array([[x0, y0], [x1, y1], [x2, y2], [x3, y3]]),
                    np.array([[x0_b, y0_b], [x1_b, y1_b], [x2_b, y2_b], [x3_b, y3_b]]),
                )

            rounded = [round(v, 2) for v in (
                x1, y1, x2, y2, x3, y3,
                x0_b, y0_b, x1_b, y1_b, x2_b, y2_b, x3_b, y3_b,
            )]
            outgt.write(OUT_TEMPLATE.format(x0, y0, *rounded, ct))

    if VISUALIZE:
        imgdir = label.replace('.txt', '.jpg')
        plt.imshow(plt.imread(imgdir))
        plt.axis('off')
        os.makedirs('vis_results', exist_ok=True)
        plt.savefig('vis_results/' + os.path.basename(imgdir), bbox_inches='tight', dpi=400)
        plt.clf()
After you get the Bezier points, you can combine them with the original text annotations to generate a COCO-format file. You need to add extra info to each annotation:
{
'area': h*w,
'bbox': box,
'category_id': cat_id,
'id': ann_id,
'image_id': image_id,
'iscrowd': 0,
'segmentation': [poly],
'text': [text],
'bezier_pts': [bezier_pts], # bezier points you generated for each text instance
'rec': [rec] # text label for recognition head
}
Then just configure the data path and run `python tools/train.py --config-file *.yaml`. The model will work well.
If you want to generate annotations for curved text, you can use the script in the README; everything I described above is only for rotated boxes.