yunlian-file-v2/client/client.go
2021-05-18 14:10:13 +08:00

451 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"crypto/md5"
"encoding/hex"
"errors"
"flag"
"fmt"
"github.com/dustin/go-humanize"
"google.golang.org/grpc"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
transportpb "yunlian-file-v2/proto"
)
// Shared download state.
var (
	// group counts the block-download goroutines that are still running.
	group sync.WaitGroup
	// pi carries the ids of blocks that gave up after exhausting their retries;
	// cleanCache consumes it and deletes the matching temp files.
	pi = make(chan string, 2)
	// exit and exitSub are single-slot signalling channels written by cleanCache.
	exit    = make(chan bool, 1)
	exitSub = make(chan bool, 1)
	// downloadContext describes the file currently being downloaded and its blocks.
	downloadContext = new(fileContext)
)

// Command-line settings, filled in by the flags registered in init.
var (
	thread  int64  // -thread: number of blocks the file is split into
	root    string // -root: directory holding per-block temp files
	attempt int    // -attempt: retry limit per block
	fn      string // -fn: name of the file to download
	ip      string // -ip: server address
	port    int    // -port: server port
)
type (
	// fileInfo describes the remote file being downloaded.
	fileInfo struct {
		filePath string // path requested from the server; Download also writes the merged file here
		fileName string // last path element of filePath
		length   int64  // total size in bytes reported by the server
		savePath string // filepath.Join(root, fileName), as set by distribFile
	}

	// block is one byte range of the file. chunkFile links the ranges into a
	// list through previous, with the last range at the head.
	block struct {
		previous *block
		id       string // temp-file name of this range (relative to root)
		start    int64  // first byte offset of the range
		end      int64  // offset just past the range; the range holds end-start bytes
		count    int    // retries used so far by startDownloadTask
	}

	// fileContext bundles the file being downloaded with its per-block temp files.
	fileContext struct {
		file      *fileInfo
		fileNames map[string]*block // temp-file path -> block it stores
		tempList  []string          // temp-file paths, last block first
	}
)
// init registers the client's command-line flags on the default flag set and
// installs a usage function that prints the version line and flag defaults.
func init() {
	fs := flag.CommandLine
	fs.Int64Var(&thread, "thread", 1, "下载线程数")           // number of download threads
	fs.IntVar(&attempt, "attempt", 10, "下载重试次数")         // download retry count
	fs.StringVar(&root, "root", "./tempdir", "临时文件目录")   // temp-file directory
	fs.StringVar(&fn, "fn", "", "要下载的文件名")              // file to download
	fs.StringVar(&ip, "ip", "127.0.0.1", "服务端IP")         // server IP
	fs.IntVar(&port, "port", 10086, "服务端端口号")            // server port

	flag.Usage = func() {
		fmt.Println("yunlian-file-v2 version: 2.0.1")
		flag.PrintDefaults()
	}
}
// main parses the flags, connects to the file server and downloads the file
// named by -fn. It prints usage when -fn is empty and exits with a non-zero
// status when the thread count is invalid or the download fails.
func main() {
	flag.Parse()
	if len(fn) < 1 {
		flag.Usage()
		return
	}
	// The file is split into `thread` blocks by dividing its length, so a
	// non-positive value cannot describe a valid split.
	if thread < 1 {
		log.Fatalf("线程数必须大于等于1, 当前为 %d", thread)
	}
	serverIp := fmt.Sprintf("%s:%d", ip, port)
	clientConn, err := grpc.Dial(serverIp, grpc.WithInsecure())
	if err != nil {
		log.Fatalf("监听服务端异常 : %v\n", err)
	}
	defer clientConn.Close()
	log.Printf("连接到 %s ...", serverIp)
	fsClient := transportpb.NewFileServiceClient(clientConn)
	if err := Download(fn, fsClient); err != nil {
		// log.Fatalf skips deferred calls, so release the connection first.
		clientConn.Close()
		log.Fatalf("下载失败: %v", err)
	}
}
// Download fetches filePath from the server as the blocks produced by
// distribFile, saving each block in its own temp file under root and
// downloading the blocks concurrently. Blocks whose temp file already has
// the expected size are skipped, which lets an interrupted download resume.
// Once every block is complete on disk, the temp files are concatenated into
// a local file at filePath (created or truncated) and then deleted.
//
// It returns an error when file information could not be obtained, when any
// block is incomplete after all download tasks have finished, or when the
// merge fails. In all of those cases the temp files are kept.
func Download(filePath string, client transportpb.FileServiceClient) (err error) {
	fileLength := distribFile(filePath, client)
	if fileLength < 1 {
		return errors.New("获取文件信息失败")
	}
	go cleanCache()
	log.Printf("正在下载: %s (%s)", filePath, humanize.IBytes(uint64(fileLength)))
	previous := time.Now()
	for key, meta := range downloadContext.fileNames {
		if checkBlockStat(key, meta) {
			continue
		}
		log.Printf("正在下载:%s区块%d-%d", key, meta.start, meta.end)
		group.Add(1)
		go startDownloadTask(downloadContext.file.filePath, key, meta, client)
	}

	// processBar only returns once the temp files add up to the full length,
	// which never happens if a block gives up. Run it alongside the workers so
	// that a failed block cannot hang the download here.
	barDone := make(chan struct{})
	go func() {
		processBar(downloadContext.file.length, previous)
		close(barDone)
	}()
	group.Wait()

	// A task that exhausts its retries still finishes, so check every block
	// on disk before merging anything.
	for _, p := range downloadContext.tempList {
		meta := downloadContext.fileNames[p]
		if fi, statErr := os.Stat(p); statErr != nil || fi.Size() != meta.end-meta.start {
			log.Println("下载失败,请重试")
			return fmt.Errorf("block %s is incomplete", p)
		}
	}
	// Every block is complete, so the bar reaches the full length and returns.
	<-barDone

	log.Println("区块下载完成,正在合并文件...")
	if err = mergeBlockFiles(downloadContext.file.filePath, downloadContext.tempList); err != nil {
		log.Println(err.Error(), "下载失败,请重试")
		return err
	}
	log.Println("合并完成,正在清除临时文件...")
	for _, file := range downloadContext.tempList {
		deleteFile(file)
	}
	log.Println("下载完成")
	return nil
}

// mergeBlockFiles creates (or truncates) dst and writes the temp files in
// parts into it. parts is ordered last block first, as distribBlock builds
// tempList, so it is walked backwards to restore file order. Each part is
// streamed instead of being loaded into memory.
func mergeBlockFiles(dst string, parts []string) (err error) {
	out, err := os.Create(dst)
	if err != nil {
		log.Println(dst, "文件创建失败")
		return err
	}
	defer func() {
		// A failed Close can mean buffered data never reached the disk.
		if cerr := out.Close(); err == nil {
			err = cerr
		}
	}()
	for i := len(parts) - 1; i >= 0; i-- {
		if err = copyFileInto(out, parts[i]); err != nil {
			return err
		}
	}
	return nil
}

// copyFileInto streams the contents of the file at path into w.
func copyFileInto(w io.Writer, path string) error {
	in, err := os.Open(path)
	if err != nil {
		return err
	}
	defer in.Close()
	_, err = io.Copy(w, in)
	return err
}
// startDownloadTask downloads block b of the remote file filePath into the
// temp file tempFilePath. Each attempt resumes after the bytes already in the
// temp file. A failed attempt (stream error, or a byte count different from
// what was requested) is retried. Once b.count exceeds attempt, the block's id
// is sent on pi for cleanup and the task gives up.
//
// The caller must call group.Add(1) first. This function calls group.Done
// exactly once, whatever the outcome.
func startDownloadTask(filePath string, tempFilePath string, b *block, client transportpb.FileServiceClient) {
	defer group.Done()
	for {
		err := fetchBlockOnce(filePath, tempFilePath, b, client)
		if err == nil {
			return
		}
		log.Printf("下载重试中: %v", err)
		if b.count > attempt {
			pi <- b.id
			return
		}
		b.count++
	}
}

// fetchBlockOnce makes a single attempt at receiving the remainder of block b,
// appending what arrives to tempFilePath. It returns an error if the stream
// fails or the number of bytes received differs from what was requested.
// It still exits the process if the RPC cannot be started or the temp file
// cannot be opened or written.
func fetchBlockOnce(filePath string, tempFilePath string, b *block, client transportpb.FileServiceClient) error {
	existSize := getFileSize(tempFilePath)
	from := b.start + existSize

	// Cancel on every return so an abandoned attempt does not leave the RPC open.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	stream, err := client.Download(ctx, &transportpb.FileRequest{
		FileName:       filePath,
		FileRangeStart: from,
		FileRangeEnd:   b.end,
	})
	if err != nil {
		log.Fatalf("下载异常 : %v\n", err)
	}
	fp, err := os.OpenFile(tempFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0755)
	if err != nil {
		log.Fatalf("文件打开异常: %s\n", err)
	}
	defer fp.Close()

	var recvSize int64
	for {
		res, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			// After a non-EOF error the stream is finished; calling Recv
			// again cannot make progress, so end this attempt.
			return fmt.Errorf("receiving %s: %w", tempFilePath, err)
		}
		chunk := res.GetBlock()
		if len(chunk) == 0 {
			continue
		}
		if _, err := fp.Write(chunk); err != nil {
			log.Fatalf("临时文件保存异常: %s\n", err)
		}
		recvSize += int64(len(chunk))
	}
	if want := b.end - from; recvSize != want {
		return fmt.Errorf("%s: received %d bytes, want %d", tempFilePath, recvSize, want)
	}
	return nil
}
// cleanCache runs in the background during a download. For every block id
// received on pi (a block that gave up after its retries), it deletes the
// block's temp file and records the failure by sending false on exit.
// It returns when a value is received on exitSub.
func cleanCache() {
	for {
		select {
		case id := <-pi:
			deleteFile(filePath(id))
			// exit holds a single value. A blocking send would wedge this
			// goroutine on the second failure; pi would then fill up and every
			// failing download task would block behind it. One false is enough
			// to report failure, so skip the send if the slot is taken.
			select {
			case exit <- false:
			default:
			}
		case <-exitSub:
			// A bare `break` here would only leave the select, not the loop.
			return
		}
	}
}
// processBar redraws a one-line progress bar on stdout about every 100ms. The
// line shows the percentage of length currently present in the block temp
// files and the seconds elapsed since t. It prints a final newline and
// returns once the temp files add up to exactly length bytes.
func processBar(length int64, t time.Time) {
	for {
		var downloaded int64
		for tempPath := range downloadContext.fileNames {
			downloaded += getFileSize(tempPath)
		}
		pct := getPercent(downloaded, length)
		filled, _ := strconv.Atoi(pct) // a non-numeric percentage draws an empty bar
		meter := bar(filled, 100)
		elapsed := fmt.Sprintf("%.f", getCurrentSize(t))
		fmt.Printf("\r%s", pct+"%"+"["+meter+"] "+" "+elapsed+"s")
		time.Sleep(100 * time.Millisecond)
		if downloaded != length {
			continue
		}
		fmt.Println("")
		return
	}
}
// bar renders a progress meter of exactly size characters: the first count
// positions are '=' and the rest are spaces. count is clamped to [0, size],
// and a non-positive size yields "".
func bar(count, size int) string {
	if size <= 0 {
		return ""
	}
	filled := count
	if filled < 0 {
		filled = 0
	}
	if filled > size {
		filled = size
	}
	// strings.Repeat builds each run in one allocation, unlike += in a loop.
	return strings.Repeat("=", filled) + strings.Repeat(" ", size-filled)
}
// getPercent returns a as a percentage of b, rounded to a whole number and
// formatted without a decimal point (e.g. 1 of 3 -> "33").
func getPercent(a int64, b int64) string {
	share := float64(a) / float64(b) * 100
	return fmt.Sprintf("%.0f", share)
}
// getCurrentSize returns the number of seconds elapsed since t.
// Despite its name it measures time, not size.
func getCurrentSize(t time.Time) float64 {
	elapsed := time.Since(t)
	return elapsed.Seconds()
}
// distribFile asks the server for the size of the file named by the last path
// element of fPath and makes sure the temp directory root exists. It then
// records the file in downloadContext and splits it into thread blocks, whose
// ids are based on the file name without its extension. It returns the size
// reported by the server, exits the process if the request fails, and panics
// if root cannot be created.
func distribFile(fPath string, client transportpb.FileServiceClient) int64 {
	baseName, fileName := parseName(fPath)
	resp, err := client.GetFileInfo(context.Background(), &transportpb.FileInfoRequest{
		FileName: fileName,
	})
	if err != nil {
		log.Fatalf("获取文件信息异常: %v\n", err)
	}
	size := resp.FileSize

	if exists := checkFileStat(root); !exists {
		if err := os.MkdirAll(root, 0777); err != nil {
			panic(err)
		}
	}

	downloadContext.file = &fileInfo{
		filePath: fPath,
		fileName: fileName,
		length:   size,
		savePath: filePath(fileName),
	}
	distribBlock(chunkFile(size, thread, baseName))
	return size
}
// parseName splits a slash-separated path. fullName is everything after the
// last '/' (the whole path if there is none). tmpName is fullName with
// everything from its last '.' removed (unchanged if it has no '.').
func parseName(fPath string) (tmpName, fullName string) {
	fullName = fPath
	if slash := strings.LastIndex(fPath, "/"); slash >= 0 {
		fullName = fPath[slash+1:]
	}
	tmpName = fullName
	if dot := strings.LastIndex(fullName, "."); dot >= 0 {
		tmpName = fullName[:dot]
	}
	return tmpName, fullName
}
// distribBlock walks the block list starting at b (following previous) and
// stores the result in downloadContext. fileNames maps each block's temp-file
// path to the block. tempList holds those paths in walk order, i.e. starting
// with b. A nil b leaves downloadContext untouched.
func distribBlock(b *block) {
	if b == nil {
		return
	}
	byPath := make(map[string]*block)
	var order []string
	for cur := b; cur != nil; cur = cur.previous {
		p := filePath(cur.id)
		byPath[p] = cur
		order = append(order, p)
	}
	downloadContext.fileNames = byPath
	downloadContext.tempList = order
}
// chunkFile splits length bytes into thread contiguous ranges. It returns them
// as a linked list whose head is the last range; each block points to the
// preceding range through previous. Every range is length/thread bytes except
// the last, which also takes the remainder. Block i (1-based) gets the id
// name + MD5 of the decimal string of i.
//
// A thread count below 1 is treated as 1 rather than panicking on the
// division by zero.
func chunkFile(length int64, thread int64, name string) (b *block) {
	if thread < 1 {
		thread = 1
	}
	blockSize := length / thread
	surplus := length % thread
	var start int64
	// With no remainder, surplus is 0 and the last block is the same size as the
	// others, so one loop covers both the even and uneven cases.
	for i := int64(1); i <= thread; i++ {
		end := blockSize * i
		if i == thread {
			end += surplus
		}
		b = &block{
			previous: b,
			id:       name + MD5(strconv.FormatInt(i, 10)),
			start:    start,
			end:      end,
		}
		start = end
	}
	return b
}
// MD5 returns the MD5 digest of text as a 32-character lowercase hex string.
func MD5(text string) string {
	digest := md5.Sum([]byte(text))
	return hex.EncodeToString(digest[:])
}
// createFileOnly makes file an empty regular file. It deletes any existing
// file first, then creates a fresh one and closes it. On failure it logs the
// path and returns the error from os.Create.
func createFileOnly(file string) error {
	if checkFileStat(file) {
		deleteFile(file)
	}
	f, err := os.Create(file)
	if err != nil {
		log.Println(file, "文件创建失败")
		return err
	}
	f.Close()
	return nil
}
// deleteFile removes file if it exists. A missing file is not an error. If
// removal fails, it logs the path and returns the error from os.Remove.
func deleteFile(file string) error {
	if checkFileStat(file) {
		if err := os.Remove(file); err != nil {
			log.Println(file, "文件删除失败")
			return err
		}
	}
	return nil
}
func checkFileStat(file string) bool {
var exist = true
if _, err := os.Stat(file); os.IsNotExist(err) {
exist = false
}
return exist
}
// checkBlockStat reports whether the temp file at filePath already holds all
// of block b, i.e. exists and is exactly b.end-b.start bytes long. It compares
// the size reported by os.Stat rather than reading the whole block into
// memory. A file that cannot be stat'ed counts as incomplete.
func checkBlockStat(filePath string, b *block) bool {
	fi, err := os.Stat(filePath)
	if err != nil {
		return false
	}
	return fi.Size() == b.end-b.start
}
// appendToFile writes content to the end of the existing file fileName. The
// file is not created if missing. It returns the first error from opening,
// writing or closing the file; an open failure is also logged.
func appendToFile(fileName string, content []byte) (err error) {
	// O_APPEND makes every write land at the current end of the file, replacing
	// the earlier Seek(0, SEEK_END) + WriteAt sequence.
	f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		log.Println("file append failed. err: " + err.Error())
		return err
	}
	defer func() {
		// A failed Close can mean the written data was not persisted.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = f.Write(content)
	return err
}
// readFile returns the entire contents of the file at path. It panics if the
// file cannot be opened or read. Read errors used to be ignored, which could
// silently return partial data.
func readFile(path string) []byte {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	return data
}
func filePath(id string) string {
var file string
file = filepath.Join(root, id)
return file
}
// getFileSize returns the size in bytes of file, or 0 if it cannot be
// stat'ed (missing, permission denied, or deleted concurrently). A single
// Stat call avoids the nil FileInfo dereference the earlier
// check-then-stat sequence could hit.
func getFileSize(file string) int64 {
	fi, err := os.Stat(file)
	if err != nil {
		return 0
	}
	return fi.Size()
}