Joonas' Note

Joonas' Note

[딥러닝 일지] 오프라인에서 파이토치 모델 불러오기 본문

AI/딥러닝

[딥러닝 일지] 오프라인에서 파이토치 모델 불러오기

2022. 3. 29. 22:35 joonas

    이전 글 - [딥러닝 일지] Conv2d 알아보기


    오류 메시지

    VGG 같은 모델을 사용하기 위해 허브에서 불러올 때 아래처럼 연결되지 않는 경우가 있다.

    import torchvision
    model = torchvision.models.vgg16_bn(pretrained=True)
    Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
    ---------------------------------------------------------------------------
    gaierror                                  Traceback (most recent call last)
    /opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
       1349                 h.request(req.get_method(), req.selector, req.data, headers,
    -> 1350                           encode_chunked=req.has_header('Transfer-encoding'))
       1351             except OSError as err: # timeout error
    
    /opt/conda/lib/python3.7/http/client.py in request(self, method, url, body, headers, encode_chunked)
       1280         """Send a complete request to the server."""
    -> 1281         self._send_request(method, url, body, headers, encode_chunked)
       1282 
    
    /opt/conda/lib/python3.7/http/client.py in _send_request(self, method, url, body, headers, encode_chunked)
       1326             body = _encode(body, 'body')
    -> 1327         self.endheaders(body, encode_chunked=encode_chunked)
       1328 
    
    /opt/conda/lib/python3.7/http/client.py in endheaders(self, message_body, encode_chunked)
       1275             raise CannotSendHeader()
    -> 1276         self._send_output(message_body, encode_chunked=encode_chunked)
       1277 
    
    /opt/conda/lib/python3.7/http/client.py in _send_output(self, message_body, encode_chunked)
       1035         del self._buffer[:]
    -> 1036         self.send(msg)
       1037 
    
    /opt/conda/lib/python3.7/http/client.py in send(self, data)
        975             if self.auto_open:
    --> 976                 self.connect()
        977             else:
    
    /opt/conda/lib/python3.7/http/client.py in connect(self)
       1442 
    -> 1443             super().connect()
       1444 
    
    /opt/conda/lib/python3.7/http/client.py in connect(self)
        947         self.sock = self._create_connection(
    --> 948             (self.host,self.port), self.timeout, self.source_address)
        949         self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    
    /opt/conda/lib/python3.7/socket.py in create_connection(address, timeout, source_address)
        706     err = None
    --> 707     for res in getaddrinfo(host, port, 0, SOCK_STREAM):
        708         af, socktype, proto, canonname, sa = res
    
    /opt/conda/lib/python3.7/socket.py in getaddrinfo(host, port, family, type, proto, flags)
        751     addrlist = []
    --> 752     for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
        753         af, socktype, proto, canonname, sa = res
    
    gaierror: [Errno -3] Temporary failure in name resolution
    
    During handling of the above exception, another exception occurred:
    
    URLError                                  Traceback (most recent call last)
    /tmp/ipykernel_33/2336501538.py in <module>
          1 import collections
          2 
    ----> 3 model = torchvision.models.vgg16_bn(pretrained=True)
          4 model.to(device)
    
    /opt/conda/lib/python3.7/site-packages/torchvision/models/vgg.py in vgg16_bn(pretrained, progress, **kwargs)
        166         progress (bool): If True, displays a progress bar of the download to stderr
        167     """
    --> 168     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
        169 
        170 
    
    /opt/conda/lib/python3.7/site-packages/torchvision/models/vgg.py in _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs)
         98     if pretrained:
         99         state_dict = load_state_dict_from_url(model_urls[arch],
    --> 100                                               progress=progress)
        101         model.load_state_dict(state_dict)
        102     return model
    
    /opt/conda/lib/python3.7/site-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
        569             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
        570             hash_prefix = r.group(1) if r else None
    --> 571         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
        572 
        573     if _is_legacy_zip_format(cached_file):
    
    /opt/conda/lib/python3.7/site-packages/torch/hub.py in download_url_to_file(url, dst, hash_prefix, progress)
        435     # certificates in older Python
        436     req = Request(url, headers={"User-Agent": "torch.hub"})
    --> 437     u = urlopen(req)
        438     meta = u.info()
        439     if hasattr(meta, 'getheaders'):
    
    /opt/conda/lib/python3.7/urllib/request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
        220     else:
        221         opener = _opener
    --> 222     return opener.open(url, data, timeout)
        223 
        224 def install_opener(opener):
    
    /opt/conda/lib/python3.7/urllib/request.py in open(self, fullurl, data, timeout)
        523             req = meth(req)
        524 
    --> 525         response = self._open(req, data)
        526 
        527         # post-process response
    
    /opt/conda/lib/python3.7/urllib/request.py in _open(self, req, data)
        541         protocol = req.type
        542         result = self._call_chain(self.handle_open, protocol, protocol +
    --> 543                                   '_open', req)
        544         if result:
        545             return result
    
    /opt/conda/lib/python3.7/urllib/request.py in _call_chain(self, chain, kind, meth_name, *args)
        501         for handler in handlers:
        502             func = getattr(handler, meth_name)
    --> 503             result = func(*args)
        504             if result is not None:
        505                 return result
    
    /opt/conda/lib/python3.7/urllib/request.py in https_open(self, req)
       1391         def https_open(self, req):
       1392             return self.do_open(http.client.HTTPSConnection, req,
    -> 1393                 context=self._context, check_hostname=self._check_hostname)
       1394 
       1395         https_request = AbstractHTTPHandler.do_request_
    
    /opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
       1350                           encode_chunked=req.has_header('Transfer-encoding'))
       1351             except OSError as err: # timeout error
    -> 1352                 raise URLError(err)
       1353             r = h.getresponse()
       1354         except:
    
    URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

    해결 방법 1

    Kaggle이나 Colab을 사용 중이라면, 노트북이 인터넷과 연결되지 않은 상태인지 먼저 확인해본다.

    kaggle 설정

    해결 방법 2

    사내 인트라넷에 세팅된 주피터 노트북과 같은 경우에는, 위와 같이 해결할 수 없을 수도 있다. (실제로 겪음)

    에러 메시지를 보면 "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" 와 같이, 어디서 다운로드를 받아오려 했는지 알려준다.

    해당 URL에서 pth 파일을 직접 다운로드 받아서, torch.load 처럼 모델을 로드해도 되는데 더 편한 방법이 있다.

    실행하는 스크립트와 같은 위치 아래에 hub/checkpoints/ 폴더를 만들고, 위에서 다운로드 받은 모델 pth 파일을 그 안에 복사한 뒤 아래처럼 설정한다. (torch.hub는 $TORCH_HOME/hub/checkpoints/ 경로에서 캐시된 파일을 찾기 때문에, 스크립트 바로 옆이 아니라 이 하위 폴더에 두어야 한다.)

    os.environ['TORCH_HOME'] = './'

    이렇게 모델을 불러오기 전에 환경 변수를 설정해 두면, torchvision.models 모듈이 지정한 위치의 캐시를 먼저 확인하기 때문에 다운로드 없이 저장된 모델을 로드한다.

    Comments